1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * common code for virtio vsock
4  *
5  * Copyright (C) 2013-2015 Red Hat, Inc.
6  * Author: Asias He <asias@redhat.com>
7  *         Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 #include <linux/spinlock.h>
10 #include <linux/module.h>
11 #include <linux/sched/signal.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18 #include <uapi/linux/vsockmon.h>
19 
20 #include <net/sock.h>
21 #include <net/af_vsock.h>
22 
23 #define CREATE_TRACE_POINTS
24 #include <trace/events/vsock_virtio_transport_common.h>
25 
26 /* How long to wait for graceful shutdown of a connection */
27 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
28 
29 static const struct virtio_transport *virtio_transport_get_ops(void)
30 {
31 	const struct vsock_transport *t = vsock_core_get_transport();
32 
33 	return container_of(t, struct virtio_transport, transport);
34 }
35 
36 static struct virtio_vsock_pkt *
37 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
38 			   size_t len,
39 			   u32 src_cid,
40 			   u32 src_port,
41 			   u32 dst_cid,
42 			   u32 dst_port)
43 {
44 	struct virtio_vsock_pkt *pkt;
45 	int err;
46 
47 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
48 	if (!pkt)
49 		return NULL;
50 
51 	pkt->hdr.type		= cpu_to_le16(info->type);
52 	pkt->hdr.op		= cpu_to_le16(info->op);
53 	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
54 	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
55 	pkt->hdr.src_port	= cpu_to_le32(src_port);
56 	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
57 	pkt->hdr.flags		= cpu_to_le32(info->flags);
58 	pkt->len		= len;
59 	pkt->hdr.len		= cpu_to_le32(len);
60 	pkt->reply		= info->reply;
61 	pkt->vsk		= info->vsk;
62 
63 	if (info->msg && len > 0) {
64 		pkt->buf = kmalloc(len, GFP_KERNEL);
65 		if (!pkt->buf)
66 			goto out_pkt;
67 		err = memcpy_from_msg(pkt->buf, info->msg, len);
68 		if (err)
69 			goto out;
70 	}
71 
72 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
73 					 dst_cid, dst_port,
74 					 len,
75 					 info->type,
76 					 info->op,
77 					 info->flags);
78 
79 	return pkt;
80 
81 out:
82 	kfree(pkt->buf);
83 out_pkt:
84 	kfree(pkt);
85 	return NULL;
86 }
87 
88 /* Packet capture */
89 static struct sk_buff *virtio_transport_build_skb(void *opaque)
90 {
91 	struct virtio_vsock_pkt *pkt = opaque;
92 	struct af_vsockmon_hdr *hdr;
93 	struct sk_buff *skb;
94 
95 	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
96 			GFP_ATOMIC);
97 	if (!skb)
98 		return NULL;
99 
100 	hdr = skb_put(skb, sizeof(*hdr));
101 
102 	/* pkt->hdr is little-endian so no need to byteswap here */
103 	hdr->src_cid = pkt->hdr.src_cid;
104 	hdr->src_port = pkt->hdr.src_port;
105 	hdr->dst_cid = pkt->hdr.dst_cid;
106 	hdr->dst_port = pkt->hdr.dst_port;
107 
108 	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
109 	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
110 	memset(hdr->reserved, 0, sizeof(hdr->reserved));
111 
112 	switch (le16_to_cpu(pkt->hdr.op)) {
113 	case VIRTIO_VSOCK_OP_REQUEST:
114 	case VIRTIO_VSOCK_OP_RESPONSE:
115 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
116 		break;
117 	case VIRTIO_VSOCK_OP_RST:
118 	case VIRTIO_VSOCK_OP_SHUTDOWN:
119 		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
120 		break;
121 	case VIRTIO_VSOCK_OP_RW:
122 		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
123 		break;
124 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
125 	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
126 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
127 		break;
128 	default:
129 		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
130 		break;
131 	}
132 
133 	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
134 
135 	if (pkt->len) {
136 		skb_put_data(skb, pkt->buf, pkt->len);
137 	}
138 
139 	return skb;
140 }
141 
142 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
143 {
144 	vsock_deliver_tap(virtio_transport_build_skb, pkt);
145 }
146 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
147 
148 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
149 					  struct virtio_vsock_pkt_info *info)
150 {
151 	u32 src_cid, src_port, dst_cid, dst_port;
152 	struct virtio_vsock_sock *vvs;
153 	struct virtio_vsock_pkt *pkt;
154 	u32 pkt_len = info->pkt_len;
155 
156 	src_cid = vm_sockets_get_local_cid();
157 	src_port = vsk->local_addr.svm_port;
158 	if (!info->remote_cid) {
159 		dst_cid	= vsk->remote_addr.svm_cid;
160 		dst_port = vsk->remote_addr.svm_port;
161 	} else {
162 		dst_cid = info->remote_cid;
163 		dst_port = info->remote_port;
164 	}
165 
166 	vvs = vsk->trans;
167 
168 	/* we can send less than pkt_len bytes */
169 	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
170 		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
171 
172 	/* virtio_transport_get_credit might return less than pkt_len credit */
173 	pkt_len = virtio_transport_get_credit(vvs, pkt_len);
174 
175 	/* Do not send zero length OP_RW pkt */
176 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
177 		return pkt_len;
178 
179 	pkt = virtio_transport_alloc_pkt(info, pkt_len,
180 					 src_cid, src_port,
181 					 dst_cid, dst_port);
182 	if (!pkt) {
183 		virtio_transport_put_credit(vvs, pkt_len);
184 		return -ENOMEM;
185 	}
186 
187 	virtio_transport_inc_tx_pkt(vvs, pkt);
188 
189 	return virtio_transport_get_ops()->send_pkt(pkt);
190 }
191 
192 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
193 					struct virtio_vsock_pkt *pkt)
194 {
195 	vvs->rx_bytes += pkt->len;
196 }
197 
198 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
199 					struct virtio_vsock_pkt *pkt)
200 {
201 	vvs->rx_bytes -= pkt->len;
202 	vvs->fwd_cnt += pkt->len;
203 }
204 
205 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
206 {
207 	spin_lock_bh(&vvs->tx_lock);
208 	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
209 	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
210 	spin_unlock_bh(&vvs->tx_lock);
211 }
212 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
213 
214 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
215 {
216 	u32 ret;
217 
218 	spin_lock_bh(&vvs->tx_lock);
219 	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
220 	if (ret > credit)
221 		ret = credit;
222 	vvs->tx_cnt += ret;
223 	spin_unlock_bh(&vvs->tx_lock);
224 
225 	return ret;
226 }
227 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
228 
229 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
230 {
231 	spin_lock_bh(&vvs->tx_lock);
232 	vvs->tx_cnt -= credit;
233 	spin_unlock_bh(&vvs->tx_lock);
234 }
235 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
236 
237 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
238 					       int type,
239 					       struct virtio_vsock_hdr *hdr)
240 {
241 	struct virtio_vsock_pkt_info info = {
242 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
243 		.type = type,
244 		.vsk = vsk,
245 	};
246 
247 	return virtio_transport_send_pkt_info(vsk, &info);
248 }
249 
250 static ssize_t
251 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
252 				   struct msghdr *msg,
253 				   size_t len)
254 {
255 	struct virtio_vsock_sock *vvs = vsk->trans;
256 	struct virtio_vsock_pkt *pkt;
257 	size_t bytes, total = 0;
258 	int err = -EFAULT;
259 
260 	spin_lock_bh(&vvs->rx_lock);
261 	while (total < len && !list_empty(&vvs->rx_queue)) {
262 		pkt = list_first_entry(&vvs->rx_queue,
263 				       struct virtio_vsock_pkt, list);
264 
265 		bytes = len - total;
266 		if (bytes > pkt->len - pkt->off)
267 			bytes = pkt->len - pkt->off;
268 
269 		/* sk_lock is held by caller so no one else can dequeue.
270 		 * Unlock rx_lock since memcpy_to_msg() may sleep.
271 		 */
272 		spin_unlock_bh(&vvs->rx_lock);
273 
274 		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
275 		if (err)
276 			goto out;
277 
278 		spin_lock_bh(&vvs->rx_lock);
279 
280 		total += bytes;
281 		pkt->off += bytes;
282 		if (pkt->off == pkt->len) {
283 			virtio_transport_dec_rx_pkt(vvs, pkt);
284 			list_del(&pkt->list);
285 			virtio_transport_free_pkt(pkt);
286 		}
287 	}
288 	spin_unlock_bh(&vvs->rx_lock);
289 
290 	/* Send a credit pkt to peer */
291 	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
292 					    NULL);
293 
294 	return total;
295 
296 out:
297 	if (total)
298 		err = total;
299 	return err;
300 }
301 
302 ssize_t
303 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
304 				struct msghdr *msg,
305 				size_t len, int flags)
306 {
307 	if (flags & MSG_PEEK)
308 		return -EOPNOTSUPP;
309 
310 	return virtio_transport_stream_do_dequeue(vsk, msg, len);
311 }
312 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
313 
314 int
315 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
316 			       struct msghdr *msg,
317 			       size_t len, int flags)
318 {
319 	return -EOPNOTSUPP;
320 }
321 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
322 
323 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
324 {
325 	struct virtio_vsock_sock *vvs = vsk->trans;
326 	s64 bytes;
327 
328 	spin_lock_bh(&vvs->rx_lock);
329 	bytes = vvs->rx_bytes;
330 	spin_unlock_bh(&vvs->rx_lock);
331 
332 	return bytes;
333 }
334 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
335 
336 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
337 {
338 	struct virtio_vsock_sock *vvs = vsk->trans;
339 	s64 bytes;
340 
341 	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
342 	if (bytes < 0)
343 		bytes = 0;
344 
345 	return bytes;
346 }
347 
348 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
349 {
350 	struct virtio_vsock_sock *vvs = vsk->trans;
351 	s64 bytes;
352 
353 	spin_lock_bh(&vvs->tx_lock);
354 	bytes = virtio_transport_has_space(vsk);
355 	spin_unlock_bh(&vvs->tx_lock);
356 
357 	return bytes;
358 }
359 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
360 
361 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
362 				    struct vsock_sock *psk)
363 {
364 	struct virtio_vsock_sock *vvs;
365 
366 	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
367 	if (!vvs)
368 		return -ENOMEM;
369 
370 	vsk->trans = vvs;
371 	vvs->vsk = vsk;
372 	if (psk) {
373 		struct virtio_vsock_sock *ptrans = psk->trans;
374 
375 		vvs->buf_size	= ptrans->buf_size;
376 		vvs->buf_size_min = ptrans->buf_size_min;
377 		vvs->buf_size_max = ptrans->buf_size_max;
378 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
379 	} else {
380 		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
381 		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
382 		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
383 	}
384 
385 	vvs->buf_alloc = vvs->buf_size;
386 
387 	spin_lock_init(&vvs->rx_lock);
388 	spin_lock_init(&vvs->tx_lock);
389 	INIT_LIST_HEAD(&vvs->rx_queue);
390 
391 	return 0;
392 }
393 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
394 
395 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
396 {
397 	struct virtio_vsock_sock *vvs = vsk->trans;
398 
399 	return vvs->buf_size;
400 }
401 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
402 
403 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
404 {
405 	struct virtio_vsock_sock *vvs = vsk->trans;
406 
407 	return vvs->buf_size_min;
408 }
409 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
410 
411 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
412 {
413 	struct virtio_vsock_sock *vvs = vsk->trans;
414 
415 	return vvs->buf_size_max;
416 }
417 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
418 
419 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
420 {
421 	struct virtio_vsock_sock *vvs = vsk->trans;
422 
423 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
424 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
425 	if (val < vvs->buf_size_min)
426 		vvs->buf_size_min = val;
427 	if (val > vvs->buf_size_max)
428 		vvs->buf_size_max = val;
429 	vvs->buf_size = val;
430 	vvs->buf_alloc = val;
431 }
432 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
433 
434 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
435 {
436 	struct virtio_vsock_sock *vvs = vsk->trans;
437 
438 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
439 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
440 	if (val > vvs->buf_size)
441 		vvs->buf_size = val;
442 	vvs->buf_size_min = val;
443 }
444 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
445 
446 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
447 {
448 	struct virtio_vsock_sock *vvs = vsk->trans;
449 
450 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
451 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
452 	if (val < vvs->buf_size)
453 		vvs->buf_size = val;
454 	vvs->buf_size_max = val;
455 }
456 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
457 
458 int
459 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
460 				size_t target,
461 				bool *data_ready_now)
462 {
463 	if (vsock_stream_has_data(vsk))
464 		*data_ready_now = true;
465 	else
466 		*data_ready_now = false;
467 
468 	return 0;
469 }
470 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
471 
472 int
473 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
474 				 size_t target,
475 				 bool *space_avail_now)
476 {
477 	s64 free_space;
478 
479 	free_space = vsock_stream_has_space(vsk);
480 	if (free_space > 0)
481 		*space_avail_now = true;
482 	else if (free_space == 0)
483 		*space_avail_now = false;
484 
485 	return 0;
486 }
487 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
488 
489 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
490 	size_t target, struct vsock_transport_recv_notify_data *data)
491 {
492 	return 0;
493 }
494 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
495 
496 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
497 	size_t target, struct vsock_transport_recv_notify_data *data)
498 {
499 	return 0;
500 }
501 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
502 
503 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
504 	size_t target, struct vsock_transport_recv_notify_data *data)
505 {
506 	return 0;
507 }
508 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
509 
510 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
511 	size_t target, ssize_t copied, bool data_read,
512 	struct vsock_transport_recv_notify_data *data)
513 {
514 	return 0;
515 }
516 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
517 
518 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
519 	struct vsock_transport_send_notify_data *data)
520 {
521 	return 0;
522 }
523 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
524 
525 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
526 	struct vsock_transport_send_notify_data *data)
527 {
528 	return 0;
529 }
530 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
531 
532 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
533 	struct vsock_transport_send_notify_data *data)
534 {
535 	return 0;
536 }
537 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
538 
539 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
540 	ssize_t written, struct vsock_transport_send_notify_data *data)
541 {
542 	return 0;
543 }
544 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
545 
546 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
547 {
548 	struct virtio_vsock_sock *vvs = vsk->trans;
549 
550 	return vvs->buf_size;
551 }
552 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
553 
554 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
555 {
556 	return true;
557 }
558 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
559 
560 bool virtio_transport_stream_allow(u32 cid, u32 port)
561 {
562 	return true;
563 }
564 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
565 
566 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
567 				struct sockaddr_vm *addr)
568 {
569 	return -EOPNOTSUPP;
570 }
571 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
572 
573 bool virtio_transport_dgram_allow(u32 cid, u32 port)
574 {
575 	return false;
576 }
577 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
578 
579 int virtio_transport_connect(struct vsock_sock *vsk)
580 {
581 	struct virtio_vsock_pkt_info info = {
582 		.op = VIRTIO_VSOCK_OP_REQUEST,
583 		.type = VIRTIO_VSOCK_TYPE_STREAM,
584 		.vsk = vsk,
585 	};
586 
587 	return virtio_transport_send_pkt_info(vsk, &info);
588 }
589 EXPORT_SYMBOL_GPL(virtio_transport_connect);
590 
591 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
592 {
593 	struct virtio_vsock_pkt_info info = {
594 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
595 		.type = VIRTIO_VSOCK_TYPE_STREAM,
596 		.flags = (mode & RCV_SHUTDOWN ?
597 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
598 			 (mode & SEND_SHUTDOWN ?
599 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
600 		.vsk = vsk,
601 	};
602 
603 	return virtio_transport_send_pkt_info(vsk, &info);
604 }
605 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
606 
607 int
608 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
609 			       struct sockaddr_vm *remote_addr,
610 			       struct msghdr *msg,
611 			       size_t dgram_len)
612 {
613 	return -EOPNOTSUPP;
614 }
615 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
616 
617 ssize_t
618 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
619 				struct msghdr *msg,
620 				size_t len)
621 {
622 	struct virtio_vsock_pkt_info info = {
623 		.op = VIRTIO_VSOCK_OP_RW,
624 		.type = VIRTIO_VSOCK_TYPE_STREAM,
625 		.msg = msg,
626 		.pkt_len = len,
627 		.vsk = vsk,
628 	};
629 
630 	return virtio_transport_send_pkt_info(vsk, &info);
631 }
632 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
633 
634 void virtio_transport_destruct(struct vsock_sock *vsk)
635 {
636 	struct virtio_vsock_sock *vvs = vsk->trans;
637 
638 	kfree(vvs);
639 }
640 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
641 
642 static int virtio_transport_reset(struct vsock_sock *vsk,
643 				  struct virtio_vsock_pkt *pkt)
644 {
645 	struct virtio_vsock_pkt_info info = {
646 		.op = VIRTIO_VSOCK_OP_RST,
647 		.type = VIRTIO_VSOCK_TYPE_STREAM,
648 		.reply = !!pkt,
649 		.vsk = vsk,
650 	};
651 
652 	/* Send RST only if the original pkt is not a RST pkt */
653 	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
654 		return 0;
655 
656 	return virtio_transport_send_pkt_info(vsk, &info);
657 }
658 
659 /* Normally packets are associated with a socket.  There may be no socket if an
660  * attempt was made to connect to a socket that does not exist.
661  */
662 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
663 {
664 	const struct virtio_transport *t;
665 	struct virtio_vsock_pkt *reply;
666 	struct virtio_vsock_pkt_info info = {
667 		.op = VIRTIO_VSOCK_OP_RST,
668 		.type = le16_to_cpu(pkt->hdr.type),
669 		.reply = true,
670 	};
671 
672 	/* Send RST only if the original pkt is not a RST pkt */
673 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
674 		return 0;
675 
676 	reply = virtio_transport_alloc_pkt(&info, 0,
677 					   le64_to_cpu(pkt->hdr.dst_cid),
678 					   le32_to_cpu(pkt->hdr.dst_port),
679 					   le64_to_cpu(pkt->hdr.src_cid),
680 					   le32_to_cpu(pkt->hdr.src_port));
681 	if (!reply)
682 		return -ENOMEM;
683 
684 	t = virtio_transport_get_ops();
685 	if (!t) {
686 		virtio_transport_free_pkt(reply);
687 		return -ENOTCONN;
688 	}
689 
690 	return t->send_pkt(reply);
691 }
692 
693 static void virtio_transport_wait_close(struct sock *sk, long timeout)
694 {
695 	if (timeout) {
696 		DEFINE_WAIT_FUNC(wait, woken_wake_function);
697 
698 		add_wait_queue(sk_sleep(sk), &wait);
699 
700 		do {
701 			if (sk_wait_event(sk, &timeout,
702 					  sock_flag(sk, SOCK_DONE), &wait))
703 				break;
704 		} while (!signal_pending(current) && timeout);
705 
706 		remove_wait_queue(sk_sleep(sk), &wait);
707 	}
708 }
709 
710 static void virtio_transport_do_close(struct vsock_sock *vsk,
711 				      bool cancel_timeout)
712 {
713 	struct sock *sk = sk_vsock(vsk);
714 
715 	sock_set_flag(sk, SOCK_DONE);
716 	vsk->peer_shutdown = SHUTDOWN_MASK;
717 	if (vsock_stream_has_data(vsk) <= 0)
718 		sk->sk_state = TCP_CLOSING;
719 	sk->sk_state_change(sk);
720 
721 	if (vsk->close_work_scheduled &&
722 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
723 		vsk->close_work_scheduled = false;
724 
725 		vsock_remove_sock(vsk);
726 
727 		/* Release refcnt obtained when we scheduled the timeout */
728 		sock_put(sk);
729 	}
730 }
731 
732 static void virtio_transport_close_timeout(struct work_struct *work)
733 {
734 	struct vsock_sock *vsk =
735 		container_of(work, struct vsock_sock, close_work.work);
736 	struct sock *sk = sk_vsock(vsk);
737 
738 	sock_hold(sk);
739 	lock_sock(sk);
740 
741 	if (!sock_flag(sk, SOCK_DONE)) {
742 		(void)virtio_transport_reset(vsk, NULL);
743 
744 		virtio_transport_do_close(vsk, false);
745 	}
746 
747 	vsk->close_work_scheduled = false;
748 
749 	release_sock(sk);
750 	sock_put(sk);
751 }
752 
753 /* User context, vsk->sk is locked */
754 static bool virtio_transport_close(struct vsock_sock *vsk)
755 {
756 	struct sock *sk = &vsk->sk;
757 
758 	if (!(sk->sk_state == TCP_ESTABLISHED ||
759 	      sk->sk_state == TCP_CLOSING))
760 		return true;
761 
762 	/* Already received SHUTDOWN from peer, reply with RST */
763 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
764 		(void)virtio_transport_reset(vsk, NULL);
765 		return true;
766 	}
767 
768 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
769 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
770 
771 	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
772 		virtio_transport_wait_close(sk, sk->sk_lingertime);
773 
774 	if (sock_flag(sk, SOCK_DONE)) {
775 		return true;
776 	}
777 
778 	sock_hold(sk);
779 	INIT_DELAYED_WORK(&vsk->close_work,
780 			  virtio_transport_close_timeout);
781 	vsk->close_work_scheduled = true;
782 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
783 	return false;
784 }
785 
786 void virtio_transport_release(struct vsock_sock *vsk)
787 {
788 	struct virtio_vsock_sock *vvs = vsk->trans;
789 	struct virtio_vsock_pkt *pkt, *tmp;
790 	struct sock *sk = &vsk->sk;
791 	bool remove_sock = true;
792 
793 	lock_sock(sk);
794 	if (sk->sk_type == SOCK_STREAM)
795 		remove_sock = virtio_transport_close(vsk);
796 
797 	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
798 		list_del(&pkt->list);
799 		virtio_transport_free_pkt(pkt);
800 	}
801 	release_sock(sk);
802 
803 	if (remove_sock)
804 		vsock_remove_sock(vsk);
805 }
806 EXPORT_SYMBOL_GPL(virtio_transport_release);
807 
808 static int
809 virtio_transport_recv_connecting(struct sock *sk,
810 				 struct virtio_vsock_pkt *pkt)
811 {
812 	struct vsock_sock *vsk = vsock_sk(sk);
813 	int err;
814 	int skerr;
815 
816 	switch (le16_to_cpu(pkt->hdr.op)) {
817 	case VIRTIO_VSOCK_OP_RESPONSE:
818 		sk->sk_state = TCP_ESTABLISHED;
819 		sk->sk_socket->state = SS_CONNECTED;
820 		vsock_insert_connected(vsk);
821 		sk->sk_state_change(sk);
822 		break;
823 	case VIRTIO_VSOCK_OP_INVALID:
824 		break;
825 	case VIRTIO_VSOCK_OP_RST:
826 		skerr = ECONNRESET;
827 		err = 0;
828 		goto destroy;
829 	default:
830 		skerr = EPROTO;
831 		err = -EINVAL;
832 		goto destroy;
833 	}
834 	return 0;
835 
836 destroy:
837 	virtio_transport_reset(vsk, pkt);
838 	sk->sk_state = TCP_CLOSE;
839 	sk->sk_err = skerr;
840 	sk->sk_error_report(sk);
841 	return err;
842 }
843 
844 static int
845 virtio_transport_recv_connected(struct sock *sk,
846 				struct virtio_vsock_pkt *pkt)
847 {
848 	struct vsock_sock *vsk = vsock_sk(sk);
849 	struct virtio_vsock_sock *vvs = vsk->trans;
850 	int err = 0;
851 
852 	switch (le16_to_cpu(pkt->hdr.op)) {
853 	case VIRTIO_VSOCK_OP_RW:
854 		pkt->len = le32_to_cpu(pkt->hdr.len);
855 		pkt->off = 0;
856 
857 		spin_lock_bh(&vvs->rx_lock);
858 		virtio_transport_inc_rx_pkt(vvs, pkt);
859 		list_add_tail(&pkt->list, &vvs->rx_queue);
860 		spin_unlock_bh(&vvs->rx_lock);
861 
862 		sk->sk_data_ready(sk);
863 		return err;
864 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
865 		sk->sk_write_space(sk);
866 		break;
867 	case VIRTIO_VSOCK_OP_SHUTDOWN:
868 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
869 			vsk->peer_shutdown |= RCV_SHUTDOWN;
870 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
871 			vsk->peer_shutdown |= SEND_SHUTDOWN;
872 		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
873 		    vsock_stream_has_data(vsk) <= 0) {
874 			sock_set_flag(sk, SOCK_DONE);
875 			sk->sk_state = TCP_CLOSING;
876 		}
877 		if (le32_to_cpu(pkt->hdr.flags))
878 			sk->sk_state_change(sk);
879 		break;
880 	case VIRTIO_VSOCK_OP_RST:
881 		virtio_transport_do_close(vsk, true);
882 		break;
883 	default:
884 		err = -EINVAL;
885 		break;
886 	}
887 
888 	virtio_transport_free_pkt(pkt);
889 	return err;
890 }
891 
892 static void
893 virtio_transport_recv_disconnecting(struct sock *sk,
894 				    struct virtio_vsock_pkt *pkt)
895 {
896 	struct vsock_sock *vsk = vsock_sk(sk);
897 
898 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
899 		virtio_transport_do_close(vsk, true);
900 }
901 
902 static int
903 virtio_transport_send_response(struct vsock_sock *vsk,
904 			       struct virtio_vsock_pkt *pkt)
905 {
906 	struct virtio_vsock_pkt_info info = {
907 		.op = VIRTIO_VSOCK_OP_RESPONSE,
908 		.type = VIRTIO_VSOCK_TYPE_STREAM,
909 		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
910 		.remote_port = le32_to_cpu(pkt->hdr.src_port),
911 		.reply = true,
912 		.vsk = vsk,
913 	};
914 
915 	return virtio_transport_send_pkt_info(vsk, &info);
916 }
917 
918 /* Handle server socket */
919 static int
920 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
921 {
922 	struct vsock_sock *vsk = vsock_sk(sk);
923 	struct vsock_sock *vchild;
924 	struct sock *child;
925 
926 	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
927 		virtio_transport_reset(vsk, pkt);
928 		return -EINVAL;
929 	}
930 
931 	if (sk_acceptq_is_full(sk)) {
932 		virtio_transport_reset(vsk, pkt);
933 		return -ENOMEM;
934 	}
935 
936 	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
937 			       sk->sk_type, 0);
938 	if (!child) {
939 		virtio_transport_reset(vsk, pkt);
940 		return -ENOMEM;
941 	}
942 
943 	sk->sk_ack_backlog++;
944 
945 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
946 
947 	child->sk_state = TCP_ESTABLISHED;
948 
949 	vchild = vsock_sk(child);
950 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
951 			le32_to_cpu(pkt->hdr.dst_port));
952 	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
953 			le32_to_cpu(pkt->hdr.src_port));
954 
955 	vsock_insert_connected(vchild);
956 	vsock_enqueue_accept(sk, child);
957 	virtio_transport_send_response(vchild, pkt);
958 
959 	release_sock(child);
960 
961 	sk->sk_data_ready(sk);
962 	return 0;
963 }
964 
965 static bool virtio_transport_space_update(struct sock *sk,
966 					  struct virtio_vsock_pkt *pkt)
967 {
968 	struct vsock_sock *vsk = vsock_sk(sk);
969 	struct virtio_vsock_sock *vvs = vsk->trans;
970 	bool space_available;
971 
972 	/* buf_alloc and fwd_cnt is always included in the hdr */
973 	spin_lock_bh(&vvs->tx_lock);
974 	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
975 	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
976 	space_available = virtio_transport_has_space(vsk);
977 	spin_unlock_bh(&vvs->tx_lock);
978 	return space_available;
979 }
980 
981 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
982  * lock.
983  */
984 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
985 {
986 	struct sockaddr_vm src, dst;
987 	struct vsock_sock *vsk;
988 	struct sock *sk;
989 	bool space_available;
990 
991 	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
992 			le32_to_cpu(pkt->hdr.src_port));
993 	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
994 			le32_to_cpu(pkt->hdr.dst_port));
995 
996 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
997 					dst.svm_cid, dst.svm_port,
998 					le32_to_cpu(pkt->hdr.len),
999 					le16_to_cpu(pkt->hdr.type),
1000 					le16_to_cpu(pkt->hdr.op),
1001 					le32_to_cpu(pkt->hdr.flags),
1002 					le32_to_cpu(pkt->hdr.buf_alloc),
1003 					le32_to_cpu(pkt->hdr.fwd_cnt));
1004 
1005 	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
1006 		(void)virtio_transport_reset_no_sock(pkt);
1007 		goto free_pkt;
1008 	}
1009 
1010 	/* The socket must be in connected or bound table
1011 	 * otherwise send reset back
1012 	 */
1013 	sk = vsock_find_connected_socket(&src, &dst);
1014 	if (!sk) {
1015 		sk = vsock_find_bound_socket(&dst);
1016 		if (!sk) {
1017 			(void)virtio_transport_reset_no_sock(pkt);
1018 			goto free_pkt;
1019 		}
1020 	}
1021 
1022 	vsk = vsock_sk(sk);
1023 
1024 	space_available = virtio_transport_space_update(sk, pkt);
1025 
1026 	lock_sock(sk);
1027 
1028 	/* Update CID in case it has changed after a transport reset event */
1029 	vsk->local_addr.svm_cid = dst.svm_cid;
1030 
1031 	if (space_available)
1032 		sk->sk_write_space(sk);
1033 
1034 	switch (sk->sk_state) {
1035 	case TCP_LISTEN:
1036 		virtio_transport_recv_listen(sk, pkt);
1037 		virtio_transport_free_pkt(pkt);
1038 		break;
1039 	case TCP_SYN_SENT:
1040 		virtio_transport_recv_connecting(sk, pkt);
1041 		virtio_transport_free_pkt(pkt);
1042 		break;
1043 	case TCP_ESTABLISHED:
1044 		virtio_transport_recv_connected(sk, pkt);
1045 		break;
1046 	case TCP_CLOSING:
1047 		virtio_transport_recv_disconnecting(sk, pkt);
1048 		virtio_transport_free_pkt(pkt);
1049 		break;
1050 	default:
1051 		virtio_transport_free_pkt(pkt);
1052 		break;
1053 	}
1054 	release_sock(sk);
1055 
1056 	/* Release refcnt obtained when we fetched this socket out of the
1057 	 * bound or connected list.
1058 	 */
1059 	sock_put(sk);
1060 	return;
1061 
1062 free_pkt:
1063 	virtio_transport_free_pkt(pkt);
1064 }
1065 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1066 
1067 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1068 {
1069 	kfree(pkt->buf);
1070 	kfree(pkt);
1071 }
1072 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1073 
1074 MODULE_LICENSE("GPL v2");
1075 MODULE_AUTHOR("Asias He");
1076 MODULE_DESCRIPTION("common code for virtio vsock");
1077