1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * common code for virtio vsock 4 * 5 * Copyright (C) 2013-2015 Red Hat, Inc. 6 * Author: Asias He <asias@redhat.com> 7 * Stefan Hajnoczi <stefanha@redhat.com> 8 */ 9 #include <linux/spinlock.h> 10 #include <linux/module.h> 11 #include <linux/sched/signal.h> 12 #include <linux/ctype.h> 13 #include <linux/list.h> 14 #include <linux/virtio.h> 15 #include <linux/virtio_ids.h> 16 #include <linux/virtio_config.h> 17 #include <linux/virtio_vsock.h> 18 #include <uapi/linux/vsockmon.h> 19 20 #include <net/sock.h> 21 #include <net/af_vsock.h> 22 23 #define CREATE_TRACE_POINTS 24 #include <trace/events/vsock_virtio_transport_common.h> 25 26 /* How long to wait for graceful shutdown of a connection */ 27 #define VSOCK_CLOSE_TIMEOUT (8 * HZ) 28 29 static const struct virtio_transport *virtio_transport_get_ops(void) 30 { 31 const struct vsock_transport *t = vsock_core_get_transport(); 32 33 return container_of(t, struct virtio_transport, transport); 34 } 35 36 static struct virtio_vsock_pkt * 37 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, 38 size_t len, 39 u32 src_cid, 40 u32 src_port, 41 u32 dst_cid, 42 u32 dst_port) 43 { 44 struct virtio_vsock_pkt *pkt; 45 int err; 46 47 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); 48 if (!pkt) 49 return NULL; 50 51 pkt->hdr.type = cpu_to_le16(info->type); 52 pkt->hdr.op = cpu_to_le16(info->op); 53 pkt->hdr.src_cid = cpu_to_le64(src_cid); 54 pkt->hdr.dst_cid = cpu_to_le64(dst_cid); 55 pkt->hdr.src_port = cpu_to_le32(src_port); 56 pkt->hdr.dst_port = cpu_to_le32(dst_port); 57 pkt->hdr.flags = cpu_to_le32(info->flags); 58 pkt->len = len; 59 pkt->hdr.len = cpu_to_le32(len); 60 pkt->reply = info->reply; 61 pkt->vsk = info->vsk; 62 63 if (info->msg && len > 0) { 64 pkt->buf = kmalloc(len, GFP_KERNEL); 65 if (!pkt->buf) 66 goto out_pkt; 67 err = memcpy_from_msg(pkt->buf, info->msg, len); 68 if (err) 69 goto out; 70 } 71 72 trace_virtio_transport_alloc_pkt(src_cid, src_port, 73 dst_cid, dst_port, 74 len, 75 info->type, 76 info->op, 77 info->flags); 78 79 return pkt; 80 81 out: 82 kfree(pkt->buf); 83 out_pkt: 84 kfree(pkt); 85 return NULL; 86 } 87 88 /* Packet capture */ 89 static struct sk_buff *virtio_transport_build_skb(void *opaque) 90 { 91 struct virtio_vsock_pkt *pkt = opaque; 92 struct af_vsockmon_hdr *hdr; 93 struct sk_buff *skb; 94 95 skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len, 96 GFP_ATOMIC); 97 if (!skb) 98 return NULL; 99 100 hdr = skb_put(skb, sizeof(*hdr)); 101 102 /* pkt->hdr is little-endian so no need to byteswap here */ 103 hdr->src_cid = pkt->hdr.src_cid; 104 hdr->src_port = pkt->hdr.src_port; 105 hdr->dst_cid = pkt->hdr.dst_cid; 106 hdr->dst_port = pkt->hdr.dst_port; 107 108 hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO); 109 hdr->len = cpu_to_le16(sizeof(pkt->hdr)); 110 memset(hdr->reserved, 0, sizeof(hdr->reserved)); 111 112 switch (le16_to_cpu(pkt->hdr.op)) { 113 case VIRTIO_VSOCK_OP_REQUEST: 114 case VIRTIO_VSOCK_OP_RESPONSE: 115 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT); 116 break; 117 case VIRTIO_VSOCK_OP_RST: 118 case VIRTIO_VSOCK_OP_SHUTDOWN: 119 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT); 120 break; 121 case VIRTIO_VSOCK_OP_RW: 122 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD); 123 break; 124 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 125 case VIRTIO_VSOCK_OP_CREDIT_REQUEST: 126 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); 127 break; 128 default: 129 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN); 130 break; 131 } 132 133 skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr)); 134 135 if (pkt->len) { 136 skb_put_data(skb, pkt->buf, pkt->len); 137 } 138 139 return skb; 140 } 141 142 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt) 143 { 144 vsock_deliver_tap(virtio_transport_build_skb, pkt); 145 } 146 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); 147 148 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, 149 struct virtio_vsock_pkt_info *info) 150 { 151 u32 src_cid, src_port, dst_cid, dst_port; 152 struct virtio_vsock_sock *vvs; 153 struct virtio_vsock_pkt *pkt; 154 u32 pkt_len = info->pkt_len; 155 156 src_cid = vm_sockets_get_local_cid(); 157 src_port = vsk->local_addr.svm_port; 158 if (!info->remote_cid) { 159 dst_cid = vsk->remote_addr.svm_cid; 160 dst_port = vsk->remote_addr.svm_port; 161 } else { 162 dst_cid = info->remote_cid; 163 dst_port = info->remote_port; 164 } 165 166 vvs = vsk->trans; 167 168 /* we can send less than pkt_len bytes */ 169 if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) 170 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; 171 172 /* virtio_transport_get_credit might return less than pkt_len credit */ 173 pkt_len = virtio_transport_get_credit(vvs, pkt_len); 174 175 /* Do not send zero length OP_RW pkt */ 176 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) 177 return pkt_len; 178 179 pkt = virtio_transport_alloc_pkt(info, pkt_len, 180 src_cid, src_port, 181 dst_cid, dst_port); 182 if (!pkt) { 183 virtio_transport_put_credit(vvs, pkt_len); 184 return -ENOMEM; 185 } 186 187 virtio_transport_inc_tx_pkt(vvs, pkt); 188 189 return virtio_transport_get_ops()->send_pkt(pkt); 190 } 191 192 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, 193 struct virtio_vsock_pkt *pkt) 194 { 195 vvs->rx_bytes += pkt->len; 196 } 197 198 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, 199 struct virtio_vsock_pkt *pkt) 200 { 201 vvs->rx_bytes -= pkt->len; 202 vvs->fwd_cnt += pkt->len; 203 } 204 205 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) 206 { 207 spin_lock_bh(&vvs->tx_lock); 208 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); 209 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); 210 spin_unlock_bh(&vvs->tx_lock); 211 } 212 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); 213 214 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) 215 { 216 u32 ret; 217 218 spin_lock_bh(&vvs->tx_lock); 219 ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 220 if (ret > credit) 221 ret = credit; 222 vvs->tx_cnt += ret; 223 spin_unlock_bh(&vvs->tx_lock); 224 225 return ret; 226 } 227 EXPORT_SYMBOL_GPL(virtio_transport_get_credit); 228 229 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) 230 { 231 spin_lock_bh(&vvs->tx_lock); 232 vvs->tx_cnt -= credit; 233 spin_unlock_bh(&vvs->tx_lock); 234 } 235 EXPORT_SYMBOL_GPL(virtio_transport_put_credit); 236 237 static int virtio_transport_send_credit_update(struct vsock_sock *vsk, 238 int type, 239 struct virtio_vsock_hdr *hdr) 240 { 241 struct virtio_vsock_pkt_info info = { 242 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, 243 .type = type, 244 .vsk = vsk, 245 }; 246 247 return virtio_transport_send_pkt_info(vsk, &info); 248 } 249 250 static ssize_t 251 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, 252 struct msghdr *msg, 253 size_t len) 254 { 255 struct virtio_vsock_sock *vvs = vsk->trans; 256 struct virtio_vsock_pkt *pkt; 257 size_t bytes, total = 0; 258 int err = -EFAULT; 259 260 spin_lock_bh(&vvs->rx_lock); 261 while (total < len && !list_empty(&vvs->rx_queue)) { 262 pkt = list_first_entry(&vvs->rx_queue, 263 struct virtio_vsock_pkt, list); 264 265 bytes = len - total; 266 if (bytes > pkt->len - pkt->off) 267 bytes = pkt->len - pkt->off; 268 269 /* sk_lock is held by caller so no one else can dequeue. 270 * Unlock rx_lock since memcpy_to_msg() may sleep. 271 */ 272 spin_unlock_bh(&vvs->rx_lock); 273 274 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); 275 if (err) 276 goto out; 277 278 spin_lock_bh(&vvs->rx_lock); 279 280 total += bytes; 281 pkt->off += bytes; 282 if (pkt->off == pkt->len) { 283 virtio_transport_dec_rx_pkt(vvs, pkt); 284 list_del(&pkt->list); 285 virtio_transport_free_pkt(pkt); 286 } 287 } 288 spin_unlock_bh(&vvs->rx_lock); 289 290 /* Send a credit pkt to peer */ 291 virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, 292 NULL); 293 294 return total; 295 296 out: 297 if (total) 298 err = total; 299 return err; 300 } 301 302 ssize_t 303 virtio_transport_stream_dequeue(struct vsock_sock *vsk, 304 struct msghdr *msg, 305 size_t len, int flags) 306 { 307 if (flags & MSG_PEEK) 308 return -EOPNOTSUPP; 309 310 return virtio_transport_stream_do_dequeue(vsk, msg, len); 311 } 312 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); 313 314 int 315 virtio_transport_dgram_dequeue(struct vsock_sock *vsk, 316 struct msghdr *msg, 317 size_t len, int flags) 318 { 319 return -EOPNOTSUPP; 320 } 321 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); 322 323 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) 324 { 325 struct virtio_vsock_sock *vvs = vsk->trans; 326 s64 bytes; 327 328 spin_lock_bh(&vvs->rx_lock); 329 bytes = vvs->rx_bytes; 330 spin_unlock_bh(&vvs->rx_lock); 331 332 return bytes; 333 } 334 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); 335 336 static s64 virtio_transport_has_space(struct vsock_sock *vsk) 337 { 338 struct virtio_vsock_sock *vvs = vsk->trans; 339 s64 bytes; 340 341 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 342 if (bytes < 0) 343 bytes = 0; 344 345 return bytes; 346 } 347 348 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) 349 { 350 struct virtio_vsock_sock *vvs = vsk->trans; 351 s64 bytes; 352 353 spin_lock_bh(&vvs->tx_lock); 354 bytes = virtio_transport_has_space(vsk); 355 spin_unlock_bh(&vvs->tx_lock); 356 357 return bytes; 358 } 359 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); 360 361 int virtio_transport_do_socket_init(struct vsock_sock *vsk, 362 struct vsock_sock *psk) 363 { 364 struct virtio_vsock_sock *vvs; 365 366 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); 367 if (!vvs) 368 return -ENOMEM; 369 370 vsk->trans = vvs; 371 vvs->vsk = vsk; 372 if (psk) { 373 struct virtio_vsock_sock *ptrans = psk->trans; 374 375 vvs->buf_size = ptrans->buf_size; 376 vvs->buf_size_min = ptrans->buf_size_min; 377 vvs->buf_size_max = ptrans->buf_size_max; 378 vvs->peer_buf_alloc = ptrans->peer_buf_alloc; 379 } else { 380 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; 381 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; 382 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; 383 } 384 385 vvs->buf_alloc = vvs->buf_size; 386 387 spin_lock_init(&vvs->rx_lock); 388 spin_lock_init(&vvs->tx_lock); 389 INIT_LIST_HEAD(&vvs->rx_queue); 390 391 return 0; 392 } 393 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); 394 395 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) 396 { 397 struct virtio_vsock_sock *vvs = vsk->trans; 398 399 return vvs->buf_size; 400 } 401 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); 402 403 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) 404 { 405 struct virtio_vsock_sock *vvs = vsk->trans; 406 407 return vvs->buf_size_min; 408 } 409 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); 410 411 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) 412 { 413 struct virtio_vsock_sock *vvs = vsk->trans; 414 415 return vvs->buf_size_max; 416 } 417 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); 418 419 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) 420 { 421 struct virtio_vsock_sock *vvs = vsk->trans; 422 423 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 424 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 425 if (val < vvs->buf_size_min) 426 vvs->buf_size_min = val; 427 if (val > vvs->buf_size_max) 428 vvs->buf_size_max = val; 429 vvs->buf_size = val; 430 vvs->buf_alloc = val; 431 } 432 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); 433 434 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) 435 { 436 struct virtio_vsock_sock *vvs = vsk->trans; 437 438 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 439 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 440 if (val > vvs->buf_size) 441 vvs->buf_size = val; 442 vvs->buf_size_min = val; 443 } 444 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); 445 446 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) 447 { 448 struct virtio_vsock_sock *vvs = vsk->trans; 449 450 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 451 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 452 if (val < vvs->buf_size) 453 vvs->buf_size = val; 454 vvs->buf_size_max = val; 455 } 456 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); 457 458 int 459 virtio_transport_notify_poll_in(struct vsock_sock *vsk, 460 size_t target, 461 bool *data_ready_now) 462 { 463 if (vsock_stream_has_data(vsk)) 464 *data_ready_now = true; 465 else 466 *data_ready_now = false; 467 468 return 0; 469 } 470 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); 471 472 int 473 virtio_transport_notify_poll_out(struct vsock_sock *vsk, 474 size_t target, 475 bool *space_avail_now) 476 { 477 s64 free_space; 478 479 free_space = vsock_stream_has_space(vsk); 480 if (free_space > 0) 481 *space_avail_now = true; 482 else if (free_space == 0) 483 *space_avail_now = false; 484 485 return 0; 486 } 487 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); 488 489 int virtio_transport_notify_recv_init(struct vsock_sock *vsk, 490 size_t target, struct vsock_transport_recv_notify_data *data) 491 { 492 return 0; 493 } 494 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); 495 496 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, 497 size_t target, struct vsock_transport_recv_notify_data *data) 498 { 499 return 0; 500 } 501 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); 502 503 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, 504 size_t target, struct vsock_transport_recv_notify_data *data) 505 { 506 return 0; 507 } 508 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); 509 510 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, 511 size_t target, ssize_t copied, bool data_read, 512 struct vsock_transport_recv_notify_data *data) 513 { 514 return 0; 515 } 516 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); 517 518 int virtio_transport_notify_send_init(struct vsock_sock *vsk, 519 struct vsock_transport_send_notify_data *data) 520 { 521 return 0; 522 } 523 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); 524 525 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, 526 struct vsock_transport_send_notify_data *data) 527 { 528 return 0; 529 } 530 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); 531 532 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, 533 struct vsock_transport_send_notify_data *data) 534 { 535 return 0; 536 } 537 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); 538 539 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, 540 ssize_t written, struct vsock_transport_send_notify_data *data) 541 { 542 return 0; 543 } 544 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); 545 546 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) 547 { 548 struct virtio_vsock_sock *vvs = vsk->trans; 549 550 return vvs->buf_size; 551 } 552 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); 553 554 bool virtio_transport_stream_is_active(struct vsock_sock *vsk) 555 { 556 return true; 557 } 558 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); 559 560 bool virtio_transport_stream_allow(u32 cid, u32 port) 561 { 562 return true; 563 } 564 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); 565 566 int virtio_transport_dgram_bind(struct vsock_sock *vsk, 567 struct sockaddr_vm *addr) 568 { 569 return -EOPNOTSUPP; 570 } 571 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); 572 573 bool virtio_transport_dgram_allow(u32 cid, u32 port) 574 { 575 return false; 576 } 577 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); 578 579 int virtio_transport_connect(struct vsock_sock *vsk) 580 { 581 struct virtio_vsock_pkt_info info = { 582 .op = VIRTIO_VSOCK_OP_REQUEST, 583 .type = VIRTIO_VSOCK_TYPE_STREAM, 584 .vsk = vsk, 585 }; 586 587 return virtio_transport_send_pkt_info(vsk, &info); 588 } 589 EXPORT_SYMBOL_GPL(virtio_transport_connect); 590 591 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) 592 { 593 struct virtio_vsock_pkt_info info = { 594 .op = VIRTIO_VSOCK_OP_SHUTDOWN, 595 .type = VIRTIO_VSOCK_TYPE_STREAM, 596 .flags = (mode & RCV_SHUTDOWN ? 597 VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | 598 (mode & SEND_SHUTDOWN ? 599 VIRTIO_VSOCK_SHUTDOWN_SEND : 0), 600 .vsk = vsk, 601 }; 602 603 return virtio_transport_send_pkt_info(vsk, &info); 604 } 605 EXPORT_SYMBOL_GPL(virtio_transport_shutdown); 606 607 int 608 virtio_transport_dgram_enqueue(struct vsock_sock *vsk, 609 struct sockaddr_vm *remote_addr, 610 struct msghdr *msg, 611 size_t dgram_len) 612 { 613 return -EOPNOTSUPP; 614 } 615 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); 616 617 ssize_t 618 virtio_transport_stream_enqueue(struct vsock_sock *vsk, 619 struct msghdr *msg, 620 size_t len) 621 { 622 struct virtio_vsock_pkt_info info = { 623 .op = VIRTIO_VSOCK_OP_RW, 624 .type = VIRTIO_VSOCK_TYPE_STREAM, 625 .msg = msg, 626 .pkt_len = len, 627 .vsk = vsk, 628 }; 629 630 return virtio_transport_send_pkt_info(vsk, &info); 631 } 632 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); 633 634 void virtio_transport_destruct(struct vsock_sock *vsk) 635 { 636 struct virtio_vsock_sock *vvs = vsk->trans; 637 638 kfree(vvs); 639 } 640 EXPORT_SYMBOL_GPL(virtio_transport_destruct); 641 642 static int virtio_transport_reset(struct vsock_sock *vsk, 643 struct virtio_vsock_pkt *pkt) 644 { 645 struct virtio_vsock_pkt_info info = { 646 .op = VIRTIO_VSOCK_OP_RST, 647 .type = VIRTIO_VSOCK_TYPE_STREAM, 648 .reply = !!pkt, 649 .vsk = vsk, 650 }; 651 652 /* Send RST only if the original pkt is not a RST pkt */ 653 if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 654 return 0; 655 656 return virtio_transport_send_pkt_info(vsk, &info); 657 } 658 659 /* Normally packets are associated with a socket. There may be no socket if an 660 * attempt was made to connect to a socket that does not exist. 661 */ 662 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) 663 { 664 const struct virtio_transport *t; 665 struct virtio_vsock_pkt *reply; 666 struct virtio_vsock_pkt_info info = { 667 .op = VIRTIO_VSOCK_OP_RST, 668 .type = le16_to_cpu(pkt->hdr.type), 669 .reply = true, 670 }; 671 672 /* Send RST only if the original pkt is not a RST pkt */ 673 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 674 return 0; 675 676 reply = virtio_transport_alloc_pkt(&info, 0, 677 le64_to_cpu(pkt->hdr.dst_cid), 678 le32_to_cpu(pkt->hdr.dst_port), 679 le64_to_cpu(pkt->hdr.src_cid), 680 le32_to_cpu(pkt->hdr.src_port)); 681 if (!reply) 682 return -ENOMEM; 683 684 t = virtio_transport_get_ops(); 685 if (!t) { 686 virtio_transport_free_pkt(reply); 687 return -ENOTCONN; 688 } 689 690 return t->send_pkt(reply); 691 } 692 693 static void virtio_transport_wait_close(struct sock *sk, long timeout) 694 { 695 if (timeout) { 696 DEFINE_WAIT_FUNC(wait, woken_wake_function); 697 698 add_wait_queue(sk_sleep(sk), &wait); 699 700 do { 701 if (sk_wait_event(sk, &timeout, 702 sock_flag(sk, SOCK_DONE), &wait)) 703 break; 704 } while (!signal_pending(current) && timeout); 705 706 remove_wait_queue(sk_sleep(sk), &wait); 707 } 708 } 709 710 static void virtio_transport_do_close(struct vsock_sock *vsk, 711 bool cancel_timeout) 712 { 713 struct sock *sk = sk_vsock(vsk); 714 715 sock_set_flag(sk, SOCK_DONE); 716 vsk->peer_shutdown = SHUTDOWN_MASK; 717 if (vsock_stream_has_data(vsk) <= 0) 718 sk->sk_state = TCP_CLOSING; 719 sk->sk_state_change(sk); 720 721 if (vsk->close_work_scheduled && 722 (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { 723 vsk->close_work_scheduled = false; 724 725 vsock_remove_sock(vsk); 726 727 /* Release refcnt obtained when we scheduled the timeout */ 728 sock_put(sk); 729 } 730 } 731 732 static void virtio_transport_close_timeout(struct work_struct *work) 733 { 734 struct vsock_sock *vsk = 735 container_of(work, struct vsock_sock, close_work.work); 736 struct sock *sk = sk_vsock(vsk); 737 738 sock_hold(sk); 739 lock_sock(sk); 740 741 if (!sock_flag(sk, SOCK_DONE)) { 742 (void)virtio_transport_reset(vsk, NULL); 743 744 virtio_transport_do_close(vsk, false); 745 } 746 747 vsk->close_work_scheduled = false; 748 749 release_sock(sk); 750 sock_put(sk); 751 } 752 753 /* User context, vsk->sk is locked */ 754 static bool virtio_transport_close(struct vsock_sock *vsk) 755 { 756 struct sock *sk = &vsk->sk; 757 758 if (!(sk->sk_state == TCP_ESTABLISHED || 759 sk->sk_state == TCP_CLOSING)) 760 return true; 761 762 /* Already received SHUTDOWN from peer, reply with RST */ 763 if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { 764 (void)virtio_transport_reset(vsk, NULL); 765 return true; 766 } 767 768 if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) 769 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); 770 771 if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) 772 virtio_transport_wait_close(sk, sk->sk_lingertime); 773 774 if (sock_flag(sk, SOCK_DONE)) { 775 return true; 776 } 777 778 sock_hold(sk); 779 INIT_DELAYED_WORK(&vsk->close_work, 780 virtio_transport_close_timeout); 781 vsk->close_work_scheduled = true; 782 schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); 783 return false; 784 } 785 786 void virtio_transport_release(struct vsock_sock *vsk) 787 { 788 struct virtio_vsock_sock *vvs = vsk->trans; 789 struct virtio_vsock_pkt *pkt, *tmp; 790 struct sock *sk = &vsk->sk; 791 bool remove_sock = true; 792 793 lock_sock(sk); 794 if (sk->sk_type == SOCK_STREAM) 795 remove_sock = virtio_transport_close(vsk); 796 797 list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { 798 list_del(&pkt->list); 799 virtio_transport_free_pkt(pkt); 800 } 801 release_sock(sk); 802 803 if (remove_sock) 804 vsock_remove_sock(vsk); 805 } 806 EXPORT_SYMBOL_GPL(virtio_transport_release); 807 808 static int 809 virtio_transport_recv_connecting(struct sock *sk, 810 struct virtio_vsock_pkt *pkt) 811 { 812 struct vsock_sock *vsk = vsock_sk(sk); 813 int err; 814 int skerr; 815 816 switch (le16_to_cpu(pkt->hdr.op)) { 817 case VIRTIO_VSOCK_OP_RESPONSE: 818 sk->sk_state = TCP_ESTABLISHED; 819 sk->sk_socket->state = SS_CONNECTED; 820 vsock_insert_connected(vsk); 821 sk->sk_state_change(sk); 822 break; 823 case VIRTIO_VSOCK_OP_INVALID: 824 break; 825 case VIRTIO_VSOCK_OP_RST: 826 skerr = ECONNRESET; 827 err = 0; 828 goto destroy; 829 default: 830 skerr = EPROTO; 831 err = -EINVAL; 832 goto destroy; 833 } 834 return 0; 835 836 destroy: 837 virtio_transport_reset(vsk, pkt); 838 sk->sk_state = TCP_CLOSE; 839 sk->sk_err = skerr; 840 sk->sk_error_report(sk); 841 return err; 842 } 843 844 static int 845 virtio_transport_recv_connected(struct sock *sk, 846 struct virtio_vsock_pkt *pkt) 847 { 848 struct vsock_sock *vsk = vsock_sk(sk); 849 struct virtio_vsock_sock *vvs = vsk->trans; 850 int err = 0; 851 852 switch (le16_to_cpu(pkt->hdr.op)) { 853 case VIRTIO_VSOCK_OP_RW: 854 pkt->len = le32_to_cpu(pkt->hdr.len); 855 pkt->off = 0; 856 857 spin_lock_bh(&vvs->rx_lock); 858 virtio_transport_inc_rx_pkt(vvs, pkt); 859 list_add_tail(&pkt->list, &vvs->rx_queue); 860 spin_unlock_bh(&vvs->rx_lock); 861 862 sk->sk_data_ready(sk); 863 return err; 864 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 865 sk->sk_write_space(sk); 866 break; 867 case VIRTIO_VSOCK_OP_SHUTDOWN: 868 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) 869 vsk->peer_shutdown |= RCV_SHUTDOWN; 870 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) 871 vsk->peer_shutdown |= SEND_SHUTDOWN; 872 if (vsk->peer_shutdown == SHUTDOWN_MASK && 873 vsock_stream_has_data(vsk) <= 0) { 874 sock_set_flag(sk, SOCK_DONE); 875 sk->sk_state = TCP_CLOSING; 876 } 877 if (le32_to_cpu(pkt->hdr.flags)) 878 sk->sk_state_change(sk); 879 break; 880 case VIRTIO_VSOCK_OP_RST: 881 virtio_transport_do_close(vsk, true); 882 break; 883 default: 884 err = -EINVAL; 885 break; 886 } 887 888 virtio_transport_free_pkt(pkt); 889 return err; 890 } 891 892 static void 893 virtio_transport_recv_disconnecting(struct sock *sk, 894 struct virtio_vsock_pkt *pkt) 895 { 896 struct vsock_sock *vsk = vsock_sk(sk); 897 898 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 899 virtio_transport_do_close(vsk, true); 900 } 901 902 static int 903 virtio_transport_send_response(struct vsock_sock *vsk, 904 struct virtio_vsock_pkt *pkt) 905 { 906 struct virtio_vsock_pkt_info info = { 907 .op = VIRTIO_VSOCK_OP_RESPONSE, 908 .type = VIRTIO_VSOCK_TYPE_STREAM, 909 .remote_cid = le64_to_cpu(pkt->hdr.src_cid), 910 .remote_port = le32_to_cpu(pkt->hdr.src_port), 911 .reply = true, 912 .vsk = vsk, 913 }; 914 915 return virtio_transport_send_pkt_info(vsk, &info); 916 } 917 918 /* Handle server socket */ 919 static int 920 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) 921 { 922 struct vsock_sock *vsk = vsock_sk(sk); 923 struct vsock_sock *vchild; 924 struct sock *child; 925 926 if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { 927 virtio_transport_reset(vsk, pkt); 928 return -EINVAL; 929 } 930 931 if (sk_acceptq_is_full(sk)) { 932 virtio_transport_reset(vsk, pkt); 933 return -ENOMEM; 934 } 935 936 child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, 937 sk->sk_type, 0); 938 if (!child) { 939 virtio_transport_reset(vsk, pkt); 940 return -ENOMEM; 941 } 942 943 sk->sk_ack_backlog++; 944 945 lock_sock_nested(child, SINGLE_DEPTH_NESTING); 946 947 child->sk_state = TCP_ESTABLISHED; 948 949 vchild = vsock_sk(child); 950 vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid), 951 le32_to_cpu(pkt->hdr.dst_port)); 952 vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), 953 le32_to_cpu(pkt->hdr.src_port)); 954 955 vsock_insert_connected(vchild); 956 vsock_enqueue_accept(sk, child); 957 virtio_transport_send_response(vchild, pkt); 958 959 release_sock(child); 960 961 sk->sk_data_ready(sk); 962 return 0; 963 } 964 965 static bool virtio_transport_space_update(struct sock *sk, 966 struct virtio_vsock_pkt *pkt) 967 { 968 struct vsock_sock *vsk = vsock_sk(sk); 969 struct virtio_vsock_sock *vvs = vsk->trans; 970 bool space_available; 971 972 /* buf_alloc and fwd_cnt is always included in the hdr */ 973 spin_lock_bh(&vvs->tx_lock); 974 vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); 975 vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); 976 space_available = virtio_transport_has_space(vsk); 977 spin_unlock_bh(&vvs->tx_lock); 978 return space_available; 979 } 980 981 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex 982 * lock. 983 */ 984 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) 985 { 986 struct sockaddr_vm src, dst; 987 struct vsock_sock *vsk; 988 struct sock *sk; 989 bool space_available; 990 991 vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid), 992 le32_to_cpu(pkt->hdr.src_port)); 993 vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid), 994 le32_to_cpu(pkt->hdr.dst_port)); 995 996 trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, 997 dst.svm_cid, dst.svm_port, 998 le32_to_cpu(pkt->hdr.len), 999 le16_to_cpu(pkt->hdr.type), 1000 le16_to_cpu(pkt->hdr.op), 1001 le32_to_cpu(pkt->hdr.flags), 1002 le32_to_cpu(pkt->hdr.buf_alloc), 1003 le32_to_cpu(pkt->hdr.fwd_cnt)); 1004 1005 if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { 1006 (void)virtio_transport_reset_no_sock(pkt); 1007 goto free_pkt; 1008 } 1009 1010 /* The socket must be in connected or bound table 1011 * otherwise send reset back 1012 */ 1013 sk = vsock_find_connected_socket(&src, &dst); 1014 if (!sk) { 1015 sk = vsock_find_bound_socket(&dst); 1016 if (!sk) { 1017 (void)virtio_transport_reset_no_sock(pkt); 1018 goto free_pkt; 1019 } 1020 } 1021 1022 vsk = vsock_sk(sk); 1023 1024 space_available = virtio_transport_space_update(sk, pkt); 1025 1026 lock_sock(sk); 1027 1028 /* Update CID in case it has changed after a transport reset event */ 1029 vsk->local_addr.svm_cid = dst.svm_cid; 1030 1031 if (space_available) 1032 sk->sk_write_space(sk); 1033 1034 switch (sk->sk_state) { 1035 case TCP_LISTEN: 1036 virtio_transport_recv_listen(sk, pkt); 1037 virtio_transport_free_pkt(pkt); 1038 break; 1039 case TCP_SYN_SENT: 1040 virtio_transport_recv_connecting(sk, pkt); 1041 virtio_transport_free_pkt(pkt); 1042 break; 1043 case TCP_ESTABLISHED: 1044 virtio_transport_recv_connected(sk, pkt); 1045 break; 1046 case TCP_CLOSING: 1047 virtio_transport_recv_disconnecting(sk, pkt); 1048 virtio_transport_free_pkt(pkt); 1049 break; 1050 default: 1051 virtio_transport_free_pkt(pkt); 1052 break; 1053 } 1054 release_sock(sk); 1055 1056 /* Release refcnt obtained when we fetched this socket out of the 1057 * bound or connected list. 1058 */ 1059 sock_put(sk); 1060 return; 1061 1062 free_pkt: 1063 virtio_transport_free_pkt(pkt); 1064 } 1065 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); 1066 1067 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) 1068 { 1069 kfree(pkt->buf); 1070 kfree(pkt); 1071 } 1072 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); 1073 1074 MODULE_LICENSE("GPL v2"); 1075 MODULE_AUTHOR("Asias He"); 1076 MODULE_DESCRIPTION("common code for virtio vsock"); 1077