// SPDX-License-Identifier: GPL-2.0-only
/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

/* Threshold for detecting small packets to copy */
#define GOOD_COPY_LEN 128

static const struct virtio_transport *
virtio_transport_get_ops(struct vsock_sock *vsk)
{
	const struct vsock_transport *t = vsock_core_get_transport(vsk);

	if (WARN_ON(!t))
		return NULL;

	return container_of(t, struct virtio_transport, transport);
}

static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_pkt *pkt;
	int err;

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	pkt->hdr.type = cpu_to_le16(info->type);
	pkt->hdr.op = cpu_to_le16(info->op);
	pkt->hdr.src_cid = cpu_to_le64(src_cid);
	pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
	pkt->hdr.src_port = cpu_to_le32(src_port);
	pkt->hdr.dst_port = cpu_to_le32(dst_port);
	pkt->hdr.flags = cpu_to_le32(info->flags);
	pkt->len = len;
	pkt->hdr.len = cpu_to_le32(len);
	pkt->reply = info->reply;
	pkt->vsk = info->vsk;

	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;

		pkt->buf_len = len;

		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}

/* Packet capture */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_pkt *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so we can retrieve
	 * the payload length from the header and the buffer pointer taking
	 * care of the offset in the original packet.
	 */
	payload_len = le32_to_cpu(pkt->hdr.len);
	payload_buf = pkt->buf + pkt->off;

	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt->hdr.src_cid;
	hdr->src_port = pkt->hdr.src_port;
	hdr->dst_cid = pkt->hdr.dst_cid;
	hdr->dst_port = pkt->hdr.dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

	if (payload_len)
		skb_put_data(skb, payload_buf, payload_len);

	return skb;
}

void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
	vsock_deliver_tap(virtio_transport_build_skb, pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
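
/* For reference, the skb handed to vsockmon taps by the code above is laid
 * out as three consecutive regions:
 *
 *	struct af_vsockmon_hdr | struct virtio_vsock_hdr | payload bytes
 *
 * A capture tool parses the monitor header first and then, because hdr->len
 * records sizeof(pkt->hdr), can locate the payload that follows.
 */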

/* This function can only be used on connecting/connected sockets,
 * since a socket assigned to a transport is required.
 *
 * Do not use on listener sockets!
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	const struct virtio_transport *t_ops;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	t_ops = virtio_transport_get_ops(vsk);
	if (unlikely(!t_ops))
		return -EFAULT;

	src_cid = t_ops->transport.get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return t_ops->send_pkt(pkt);
}

static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
		return false;

	vvs->rx_bytes += pkt->len;
	return true;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->rx_lock);
	vvs->last_fwd_cnt = vvs->fwd_cnt;
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->rx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
					       int type,
					       struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
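
/* A note on the credit scheme used above: the sender may only have
 * peer_buf_alloc - (tx_cnt - peer_fwd_cnt) bytes in flight, where tx_cnt
 * counts bytes ever sent and peer_fwd_cnt counts bytes the peer has
 * forwarded to its application (both are free-running u32 counters, so
 * the subtraction is wraparound-safe).
 *
 * For example, with peer_buf_alloc = 262144, tx_cnt = 300000 and
 * peer_fwd_cnt = 100000, there are 200000 bytes in flight and
 * virtio_transport_get_credit() grants at most 62144 more bytes until the
 * next credit update raises peer_fwd_cnt.
 */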

static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0, off;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);

	list_for_each_entry(pkt, &vvs->rx_queue, list) {
		off = pkt->off;

		if (total == len)
			break;

		while (total < len && off < pkt->len) {
			bytes = len - total;
			if (bytes > pkt->len - off)
				bytes = pkt->len - off;

			/* sk_lock is held by caller so no one else can dequeue.
			 * Unlock rx_lock since memcpy_to_msg() may sleep.
			 */
			spin_unlock_bh(&vvs->rx_lock);

			err = memcpy_to_msg(msg, pkt->buf + off, bytes);
			if (err)
				goto out;

			spin_lock_bh(&vvs->rx_lock);

			total += bytes;
			off += bytes;
		}
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;

out:
	if (total)
		err = total;
	return err;
}
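
/* Unlike the peek path above, the dequeue path below consumes data: it
 * advances pkt->off, frees fully drained packets and updates fwd_cnt via
 * virtio_transport_dec_rx_pkt(), which is what eventually lets the peer
 * send more data. Peeking never touches the counters, so MSG_PEEK alone
 * cannot unblock a remote writer.
 */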

static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	u32 free_space;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}

	free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);

	spin_unlock_bh(&vvs->rx_lock);

	/* To reduce the number of credit update messages,
	 * don't update credits as long as lots of space is available.
	 * Note: the limit chosen here is arbitrary. Setting the limit
	 * too high causes extra messages. Too low causes transmitter
	 * stalls. As stalls are in theory more expensive than extra
	 * messages, we set the limit to a high value. TODO: experiment
	 * with different values.
	 */
	if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		virtio_transport_send_credit_update(vsk,
						    VIRTIO_VSOCK_TYPE_STREAM,
						    NULL);
	}

	return total;

out:
	if (total)
		err = total;
	return err;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_stream_do_peek(vsk, msg, len);
	else
		return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk && psk->trans) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	}

	if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
		vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = vsk->buffer_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	INIT_LIST_HEAD(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

/* sk_lock held by the caller */
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		*val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = *val;

	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
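
/* The callback above runs when userspace resizes the receive buffer. As a
 * purely illustrative sketch (not part of this file), an application would
 * trigger it roughly like this:
 *
 *	unsigned long long size = 1024 * 1024;
 *
 *	setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
 *		   &size, sizeof(size));
 *
 * The new value is clamped to VIRTIO_VSOCK_MAX_BUF_SIZE and advertised to
 * the peer immediately through a CREDIT_UPDATE packet.
 */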

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	if (vsock_stream_has_data(vsk))
		*data_ready_now = true;
	else
		*data_ready_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return vsk->buffer_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
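
/* The REQUEST sent by virtio_transport_connect() is the first half of the
 * stream handshake: the peer's listener answers with OP_RESPONSE (handled
 * in virtio_transport_recv_connecting() below), which moves the socket to
 * TCP_ESTABLISHED, or with OP_RST if nobody is bound to the port.
 */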

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket. There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
					  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt *reply;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	reply = virtio_transport_alloc_pkt(&info, 0,
					   le64_to_cpu(pkt->hdr.dst_cid),
					   le32_to_cpu(pkt->hdr.dst_port),
					   le64_to_cpu(pkt->hdr.src_cid),
					   le32_to_cpu(pkt->hdr.src_port));
	if (!reply)
		return -ENOMEM;

	if (!t) {
		virtio_transport_free_pkt(reply);
		return -ENOTCONN;
	}

	return t->send_pkt(reply);
}

static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}
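
/* Close sequence, as implemented by the helpers above and below: a graceful
 * close sends SHUTDOWN with both RCV and SEND flags and then waits up to
 * VSOCK_CLOSE_TIMEOUT (8 seconds) for the peer's RST before forcing the
 * socket down from the delayed work. SOCK_DONE marks the point of no
 * return; once it is set, neither the timeout handler nor the receive
 * path tears the connection down again.
 */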

static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE))
		return true;

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt, *tmp;
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);

	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
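
/* The enqueue helper below coalesces small packets: when a packet of at
 * most GOOD_COPY_LEN bytes arrives and the last queued packet still has
 * room in its buffer, the payload is memcpy'd there instead of queueing
 * another full receive buffer for a tiny payload. For example, a 64-byte
 * packet landing behind a queued packet with buf_len 4096 and len 1000 is
 * copied to offset 1000, and the newly arrived packet, along with its
 * full-sized buffer, is freed immediately.
 */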

static void
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
			      struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool can_enqueue, free_pkt = false;

	pkt->len = le32_to_cpu(pkt->hdr.len);
	pkt->off = 0;

	spin_lock_bh(&vvs->rx_lock);

	can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt);
	if (!can_enqueue) {
		free_pkt = true;
		goto out;
	}

	/* Try to copy small packets into the buffer of last packet queued,
	 * to avoid wasting memory queueing the entire buffer with a small
	 * payload.
	 */
	if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
		struct virtio_vsock_pkt *last_pkt;

		last_pkt = list_last_entry(&vvs->rx_queue,
					   struct virtio_vsock_pkt, list);

		/* If there is space in the last packet queued, we copy the
		 * new packet in its buffer.
		 */
		if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
			memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
			       pkt->len);
			last_pkt->len += pkt->len;
			free_pkt = true;
			goto out;
		}
	}

	list_add_tail(&pkt->list, &vvs->rx_queue);

out:
	spin_unlock_bh(&vvs->rx_lock);
	if (free_pkt)
		virtio_transport_free_pkt(pkt);
}

static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		virtio_transport_recv_enqueue(vsk, pkt);
		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0 &&
		    !sock_flag(sk, SOCK_DONE)) {
			(void)virtio_transport_reset(vsk, NULL);

			virtio_transport_do_close(vsk, true);
		}
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
		.remote_port = le32_to_cpu(pkt->hdr.src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* Listener sockets are not associated with any transport, so we are
	 * not able to take the state to see if there is space available in
	 * the remote peer, but since they are only used to receive requests,
	 * we can assume that there is always space available in the other
	 * peer.
	 */
	if (!vvs)
		return true;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
			     struct virtio_transport *t)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;
	int ret;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset_no_sock(t, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset_no_sock(t, pkt);
		return -ENOMEM;
	}

	child = vsock_create_connected(sk);
	if (!child) {
		virtio_transport_reset_no_sock(t, pkt);
		return -ENOMEM;
	}

	sk_acceptq_added(sk);

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	ret = vsock_assign_transport(vchild, vsk);
	/* Transport assigned (looking at remote_addr) must be the same
	 * where we received the request.
	 */
	if (ret || vchild->transport != &t->transport) {
		release_sock(child);
		virtio_transport_reset_no_sock(t, pkt);
		sock_put(child);
		return ret;
	}

	if (virtio_transport_space_update(child, pkt))
		child->sk_write_space(child);

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}
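
/* Entry point for the RX path. Every incoming packet is routed here by the
 * virtio or vhost driver; the packet is matched against the connected table
 * first and the bound (listener) table second, then dispatched on
 * sk->sk_state: TCP_LISTEN goes to virtio_transport_recv_listen(),
 * TCP_SYN_SENT to virtio_transport_recv_connecting(), TCP_ESTABLISHED to
 * virtio_transport_recv_connected() and TCP_CLOSING to
 * virtio_transport_recv_disconnecting(). Packets that match no socket are
 * answered with an RST.
 */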

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_transport *t,
			       struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(t, pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(t, pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	space_available = virtio_transport_space_update(sk, pkt);

	lock_sock(sk);

	/* Update CID in case it has changed after a transport reset event */
	vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case TCP_LISTEN:
		virtio_transport_recv_listen(sk, pkt, t);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_SYN_SENT:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_ESTABLISHED:
		virtio_transport_recv_connected(sk, pkt);
		break;
	case TCP_CLOSING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		virtio_transport_free_pkt(pkt);
		break;
	}

	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");