1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/sched/signal.h>
13 #include <linux/ctype.h>
14 #include <linux/list.h>
15 #include <linux/virtio.h>
16 #include <linux/virtio_ids.h>
17 #include <linux/virtio_config.h>
18 #include <linux/virtio_vsock.h>
19 #include <uapi/linux/vsockmon.h>
20 
21 #include <net/sock.h>
22 #include <net/af_vsock.h>
23 
24 #define CREATE_TRACE_POINTS
25 #include <trace/events/vsock_virtio_transport_common.h>
26 
27 /* How long to wait for graceful shutdown of a connection */
28 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
29 
30 static const struct virtio_transport *virtio_transport_get_ops(void)
31 {
32 	const struct vsock_transport *t = vsock_core_get_transport();
33 
34 	return container_of(t, struct virtio_transport, transport);
35 }
36 
37 static struct virtio_vsock_pkt *
38 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
39 			   size_t len,
40 			   u32 src_cid,
41 			   u32 src_port,
42 			   u32 dst_cid,
43 			   u32 dst_port)
44 {
45 	struct virtio_vsock_pkt *pkt;
46 	int err;
47 
48 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
49 	if (!pkt)
50 		return NULL;
51 
52 	pkt->hdr.type		= cpu_to_le16(info->type);
53 	pkt->hdr.op		= cpu_to_le16(info->op);
54 	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
55 	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
56 	pkt->hdr.src_port	= cpu_to_le32(src_port);
57 	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
58 	pkt->hdr.flags		= cpu_to_le32(info->flags);
59 	pkt->len		= len;
60 	pkt->hdr.len		= cpu_to_le32(len);
61 	pkt->reply		= info->reply;
62 	pkt->vsk		= info->vsk;
63 
64 	if (info->msg && len > 0) {
65 		pkt->buf = kmalloc(len, GFP_KERNEL);
66 		if (!pkt->buf)
67 			goto out_pkt;
68 		err = memcpy_from_msg(pkt->buf, info->msg, len);
69 		if (err)
70 			goto out;
71 	}
72 
73 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
74 					 dst_cid, dst_port,
75 					 len,
76 					 info->type,
77 					 info->op,
78 					 info->flags);
79 
80 	return pkt;
81 
82 out:
83 	kfree(pkt->buf);
84 out_pkt:
85 	kfree(pkt);
86 	return NULL;
87 }
88 
89 /* Packet capture */
90 static struct sk_buff *virtio_transport_build_skb(void *opaque)
91 {
92 	struct virtio_vsock_pkt *pkt = opaque;
93 	struct af_vsockmon_hdr *hdr;
94 	struct sk_buff *skb;
95 
96 	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
97 			GFP_ATOMIC);
98 	if (!skb)
99 		return NULL;
100 
101 	hdr = skb_put(skb, sizeof(*hdr));
102 
103 	/* pkt->hdr is little-endian so no need to byteswap here */
104 	hdr->src_cid = pkt->hdr.src_cid;
105 	hdr->src_port = pkt->hdr.src_port;
106 	hdr->dst_cid = pkt->hdr.dst_cid;
107 	hdr->dst_port = pkt->hdr.dst_port;
108 
109 	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
110 	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
111 	memset(hdr->reserved, 0, sizeof(hdr->reserved));
112 
113 	switch (le16_to_cpu(pkt->hdr.op)) {
114 	case VIRTIO_VSOCK_OP_REQUEST:
115 	case VIRTIO_VSOCK_OP_RESPONSE:
116 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
117 		break;
118 	case VIRTIO_VSOCK_OP_RST:
119 	case VIRTIO_VSOCK_OP_SHUTDOWN:
120 		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
121 		break;
122 	case VIRTIO_VSOCK_OP_RW:
123 		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
124 		break;
125 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
126 	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
127 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
128 		break;
129 	default:
130 		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
131 		break;
132 	}
133 
134 	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
135 
136 	if (pkt->len) {
137 		skb_put_data(skb, pkt->buf, pkt->len);
138 	}
139 
140 	return skb;
141 }
142 
143 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
144 {
145 	vsock_deliver_tap(virtio_transport_build_skb, pkt);
146 }
147 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
148 
149 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
150 					  struct virtio_vsock_pkt_info *info)
151 {
152 	u32 src_cid, src_port, dst_cid, dst_port;
153 	struct virtio_vsock_sock *vvs;
154 	struct virtio_vsock_pkt *pkt;
155 	u32 pkt_len = info->pkt_len;
156 
157 	src_cid = vm_sockets_get_local_cid();
158 	src_port = vsk->local_addr.svm_port;
159 	if (!info->remote_cid) {
160 		dst_cid	= vsk->remote_addr.svm_cid;
161 		dst_port = vsk->remote_addr.svm_port;
162 	} else {
163 		dst_cid = info->remote_cid;
164 		dst_port = info->remote_port;
165 	}
166 
167 	vvs = vsk->trans;
168 
169 	/* we can send less than pkt_len bytes */
170 	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
171 		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
172 
173 	/* virtio_transport_get_credit might return less than pkt_len credit */
174 	pkt_len = virtio_transport_get_credit(vvs, pkt_len);
175 
176 	/* Do not send zero length OP_RW pkt */
177 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
178 		return pkt_len;
179 
180 	pkt = virtio_transport_alloc_pkt(info, pkt_len,
181 					 src_cid, src_port,
182 					 dst_cid, dst_port);
183 	if (!pkt) {
184 		virtio_transport_put_credit(vvs, pkt_len);
185 		return -ENOMEM;
186 	}
187 
188 	virtio_transport_inc_tx_pkt(vvs, pkt);
189 
190 	return virtio_transport_get_ops()->send_pkt(pkt);
191 }
192 
193 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
194 					struct virtio_vsock_pkt *pkt)
195 {
196 	vvs->rx_bytes += pkt->len;
197 }
198 
199 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
200 					struct virtio_vsock_pkt *pkt)
201 {
202 	vvs->rx_bytes -= pkt->len;
203 	vvs->fwd_cnt += pkt->len;
204 }
205 
206 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
207 {
208 	spin_lock_bh(&vvs->tx_lock);
209 	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
210 	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
211 	spin_unlock_bh(&vvs->tx_lock);
212 }
213 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
214 
215 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
216 {
217 	u32 ret;
218 
219 	spin_lock_bh(&vvs->tx_lock);
220 	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
221 	if (ret > credit)
222 		ret = credit;
223 	vvs->tx_cnt += ret;
224 	spin_unlock_bh(&vvs->tx_lock);
225 
226 	return ret;
227 }
228 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
229 
230 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
231 {
232 	spin_lock_bh(&vvs->tx_lock);
233 	vvs->tx_cnt -= credit;
234 	spin_unlock_bh(&vvs->tx_lock);
235 }
236 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
237 
238 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
239 					       int type,
240 					       struct virtio_vsock_hdr *hdr)
241 {
242 	struct virtio_vsock_pkt_info info = {
243 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
244 		.type = type,
245 		.vsk = vsk,
246 	};
247 
248 	return virtio_transport_send_pkt_info(vsk, &info);
249 }
250 
251 static ssize_t
252 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
253 				   struct msghdr *msg,
254 				   size_t len)
255 {
256 	struct virtio_vsock_sock *vvs = vsk->trans;
257 	struct virtio_vsock_pkt *pkt;
258 	size_t bytes, total = 0;
259 	int err = -EFAULT;
260 
261 	spin_lock_bh(&vvs->rx_lock);
262 	while (total < len && !list_empty(&vvs->rx_queue)) {
263 		pkt = list_first_entry(&vvs->rx_queue,
264 				       struct virtio_vsock_pkt, list);
265 
266 		bytes = len - total;
267 		if (bytes > pkt->len - pkt->off)
268 			bytes = pkt->len - pkt->off;
269 
270 		/* sk_lock is held by caller so no one else can dequeue.
271 		 * Unlock rx_lock since memcpy_to_msg() may sleep.
272 		 */
273 		spin_unlock_bh(&vvs->rx_lock);
274 
275 		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
276 		if (err)
277 			goto out;
278 
279 		spin_lock_bh(&vvs->rx_lock);
280 
281 		total += bytes;
282 		pkt->off += bytes;
283 		if (pkt->off == pkt->len) {
284 			virtio_transport_dec_rx_pkt(vvs, pkt);
285 			list_del(&pkt->list);
286 			virtio_transport_free_pkt(pkt);
287 		}
288 	}
289 	spin_unlock_bh(&vvs->rx_lock);
290 
291 	/* Send a credit pkt to peer */
292 	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
293 					    NULL);
294 
295 	return total;
296 
297 out:
298 	if (total)
299 		err = total;
300 	return err;
301 }
302 
303 ssize_t
304 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
305 				struct msghdr *msg,
306 				size_t len, int flags)
307 {
308 	if (flags & MSG_PEEK)
309 		return -EOPNOTSUPP;
310 
311 	return virtio_transport_stream_do_dequeue(vsk, msg, len);
312 }
313 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
314 
315 int
316 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
317 			       struct msghdr *msg,
318 			       size_t len, int flags)
319 {
320 	return -EOPNOTSUPP;
321 }
322 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
323 
324 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
325 {
326 	struct virtio_vsock_sock *vvs = vsk->trans;
327 	s64 bytes;
328 
329 	spin_lock_bh(&vvs->rx_lock);
330 	bytes = vvs->rx_bytes;
331 	spin_unlock_bh(&vvs->rx_lock);
332 
333 	return bytes;
334 }
335 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
336 
337 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
338 {
339 	struct virtio_vsock_sock *vvs = vsk->trans;
340 	s64 bytes;
341 
342 	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
343 	if (bytes < 0)
344 		bytes = 0;
345 
346 	return bytes;
347 }
348 
349 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
350 {
351 	struct virtio_vsock_sock *vvs = vsk->trans;
352 	s64 bytes;
353 
354 	spin_lock_bh(&vvs->tx_lock);
355 	bytes = virtio_transport_has_space(vsk);
356 	spin_unlock_bh(&vvs->tx_lock);
357 
358 	return bytes;
359 }
360 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
361 
362 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
363 				    struct vsock_sock *psk)
364 {
365 	struct virtio_vsock_sock *vvs;
366 
367 	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
368 	if (!vvs)
369 		return -ENOMEM;
370 
371 	vsk->trans = vvs;
372 	vvs->vsk = vsk;
373 	if (psk) {
374 		struct virtio_vsock_sock *ptrans = psk->trans;
375 
376 		vvs->buf_size	= ptrans->buf_size;
377 		vvs->buf_size_min = ptrans->buf_size_min;
378 		vvs->buf_size_max = ptrans->buf_size_max;
379 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
380 	} else {
381 		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
382 		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
383 		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
384 	}
385 
386 	vvs->buf_alloc = vvs->buf_size;
387 
388 	spin_lock_init(&vvs->rx_lock);
389 	spin_lock_init(&vvs->tx_lock);
390 	INIT_LIST_HEAD(&vvs->rx_queue);
391 
392 	return 0;
393 }
394 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
395 
396 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
397 {
398 	struct virtio_vsock_sock *vvs = vsk->trans;
399 
400 	return vvs->buf_size;
401 }
402 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
403 
404 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
405 {
406 	struct virtio_vsock_sock *vvs = vsk->trans;
407 
408 	return vvs->buf_size_min;
409 }
410 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
411 
412 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
413 {
414 	struct virtio_vsock_sock *vvs = vsk->trans;
415 
416 	return vvs->buf_size_max;
417 }
418 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
419 
420 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
421 {
422 	struct virtio_vsock_sock *vvs = vsk->trans;
423 
424 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
425 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
426 	if (val < vvs->buf_size_min)
427 		vvs->buf_size_min = val;
428 	if (val > vvs->buf_size_max)
429 		vvs->buf_size_max = val;
430 	vvs->buf_size = val;
431 	vvs->buf_alloc = val;
432 }
433 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
434 
435 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
436 {
437 	struct virtio_vsock_sock *vvs = vsk->trans;
438 
439 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
440 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
441 	if (val > vvs->buf_size)
442 		vvs->buf_size = val;
443 	vvs->buf_size_min = val;
444 }
445 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
446 
447 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
448 {
449 	struct virtio_vsock_sock *vvs = vsk->trans;
450 
451 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
452 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
453 	if (val < vvs->buf_size)
454 		vvs->buf_size = val;
455 	vvs->buf_size_max = val;
456 }
457 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
458 
459 int
460 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
461 				size_t target,
462 				bool *data_ready_now)
463 {
464 	if (vsock_stream_has_data(vsk))
465 		*data_ready_now = true;
466 	else
467 		*data_ready_now = false;
468 
469 	return 0;
470 }
471 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
472 
473 int
474 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
475 				 size_t target,
476 				 bool *space_avail_now)
477 {
478 	s64 free_space;
479 
480 	free_space = vsock_stream_has_space(vsk);
481 	if (free_space > 0)
482 		*space_avail_now = true;
483 	else if (free_space == 0)
484 		*space_avail_now = false;
485 
486 	return 0;
487 }
488 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
489 
490 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
491 	size_t target, struct vsock_transport_recv_notify_data *data)
492 {
493 	return 0;
494 }
495 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
496 
497 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
498 	size_t target, struct vsock_transport_recv_notify_data *data)
499 {
500 	return 0;
501 }
502 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
503 
504 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
505 	size_t target, struct vsock_transport_recv_notify_data *data)
506 {
507 	return 0;
508 }
509 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
510 
511 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
512 	size_t target, ssize_t copied, bool data_read,
513 	struct vsock_transport_recv_notify_data *data)
514 {
515 	return 0;
516 }
517 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
518 
519 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
520 	struct vsock_transport_send_notify_data *data)
521 {
522 	return 0;
523 }
524 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
525 
526 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
527 	struct vsock_transport_send_notify_data *data)
528 {
529 	return 0;
530 }
531 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
532 
533 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
534 	struct vsock_transport_send_notify_data *data)
535 {
536 	return 0;
537 }
538 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
539 
540 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
541 	ssize_t written, struct vsock_transport_send_notify_data *data)
542 {
543 	return 0;
544 }
545 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
546 
547 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
548 {
549 	struct virtio_vsock_sock *vvs = vsk->trans;
550 
551 	return vvs->buf_size;
552 }
553 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
554 
555 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
556 {
557 	return true;
558 }
559 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
560 
561 bool virtio_transport_stream_allow(u32 cid, u32 port)
562 {
563 	return true;
564 }
565 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
566 
567 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
568 				struct sockaddr_vm *addr)
569 {
570 	return -EOPNOTSUPP;
571 }
572 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
573 
574 bool virtio_transport_dgram_allow(u32 cid, u32 port)
575 {
576 	return false;
577 }
578 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
579 
580 int virtio_transport_connect(struct vsock_sock *vsk)
581 {
582 	struct virtio_vsock_pkt_info info = {
583 		.op = VIRTIO_VSOCK_OP_REQUEST,
584 		.type = VIRTIO_VSOCK_TYPE_STREAM,
585 		.vsk = vsk,
586 	};
587 
588 	return virtio_transport_send_pkt_info(vsk, &info);
589 }
590 EXPORT_SYMBOL_GPL(virtio_transport_connect);
591 
592 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
593 {
594 	struct virtio_vsock_pkt_info info = {
595 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
596 		.type = VIRTIO_VSOCK_TYPE_STREAM,
597 		.flags = (mode & RCV_SHUTDOWN ?
598 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
599 			 (mode & SEND_SHUTDOWN ?
600 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
601 		.vsk = vsk,
602 	};
603 
604 	return virtio_transport_send_pkt_info(vsk, &info);
605 }
606 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
607 
608 int
609 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
610 			       struct sockaddr_vm *remote_addr,
611 			       struct msghdr *msg,
612 			       size_t dgram_len)
613 {
614 	return -EOPNOTSUPP;
615 }
616 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
617 
618 ssize_t
619 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
620 				struct msghdr *msg,
621 				size_t len)
622 {
623 	struct virtio_vsock_pkt_info info = {
624 		.op = VIRTIO_VSOCK_OP_RW,
625 		.type = VIRTIO_VSOCK_TYPE_STREAM,
626 		.msg = msg,
627 		.pkt_len = len,
628 		.vsk = vsk,
629 	};
630 
631 	return virtio_transport_send_pkt_info(vsk, &info);
632 }
633 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
634 
635 void virtio_transport_destruct(struct vsock_sock *vsk)
636 {
637 	struct virtio_vsock_sock *vvs = vsk->trans;
638 
639 	kfree(vvs);
640 }
641 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
642 
643 static int virtio_transport_reset(struct vsock_sock *vsk,
644 				  struct virtio_vsock_pkt *pkt)
645 {
646 	struct virtio_vsock_pkt_info info = {
647 		.op = VIRTIO_VSOCK_OP_RST,
648 		.type = VIRTIO_VSOCK_TYPE_STREAM,
649 		.reply = !!pkt,
650 		.vsk = vsk,
651 	};
652 
653 	/* Send RST only if the original pkt is not a RST pkt */
654 	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
655 		return 0;
656 
657 	return virtio_transport_send_pkt_info(vsk, &info);
658 }
659 
660 /* Normally packets are associated with a socket.  There may be no socket if an
661  * attempt was made to connect to a socket that does not exist.
662  */
663 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
664 {
665 	const struct virtio_transport *t;
666 	struct virtio_vsock_pkt *reply;
667 	struct virtio_vsock_pkt_info info = {
668 		.op = VIRTIO_VSOCK_OP_RST,
669 		.type = le16_to_cpu(pkt->hdr.type),
670 		.reply = true,
671 	};
672 
673 	/* Send RST only if the original pkt is not a RST pkt */
674 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
675 		return 0;
676 
677 	reply = virtio_transport_alloc_pkt(&info, 0,
678 					   le64_to_cpu(pkt->hdr.dst_cid),
679 					   le32_to_cpu(pkt->hdr.dst_port),
680 					   le64_to_cpu(pkt->hdr.src_cid),
681 					   le32_to_cpu(pkt->hdr.src_port));
682 	if (!reply)
683 		return -ENOMEM;
684 
685 	t = virtio_transport_get_ops();
686 	if (!t) {
687 		virtio_transport_free_pkt(reply);
688 		return -ENOTCONN;
689 	}
690 
691 	return t->send_pkt(reply);
692 }
693 
694 static void virtio_transport_wait_close(struct sock *sk, long timeout)
695 {
696 	if (timeout) {
697 		DEFINE_WAIT_FUNC(wait, woken_wake_function);
698 
699 		add_wait_queue(sk_sleep(sk), &wait);
700 
701 		do {
702 			if (sk_wait_event(sk, &timeout,
703 					  sock_flag(sk, SOCK_DONE), &wait))
704 				break;
705 		} while (!signal_pending(current) && timeout);
706 
707 		remove_wait_queue(sk_sleep(sk), &wait);
708 	}
709 }
710 
711 static void virtio_transport_do_close(struct vsock_sock *vsk,
712 				      bool cancel_timeout)
713 {
714 	struct sock *sk = sk_vsock(vsk);
715 
716 	sock_set_flag(sk, SOCK_DONE);
717 	vsk->peer_shutdown = SHUTDOWN_MASK;
718 	if (vsock_stream_has_data(vsk) <= 0)
719 		sk->sk_state = TCP_CLOSING;
720 	sk->sk_state_change(sk);
721 
722 	if (vsk->close_work_scheduled &&
723 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
724 		vsk->close_work_scheduled = false;
725 
726 		vsock_remove_sock(vsk);
727 
728 		/* Release refcnt obtained when we scheduled the timeout */
729 		sock_put(sk);
730 	}
731 }
732 
733 static void virtio_transport_close_timeout(struct work_struct *work)
734 {
735 	struct vsock_sock *vsk =
736 		container_of(work, struct vsock_sock, close_work.work);
737 	struct sock *sk = sk_vsock(vsk);
738 
739 	sock_hold(sk);
740 	lock_sock(sk);
741 
742 	if (!sock_flag(sk, SOCK_DONE)) {
743 		(void)virtio_transport_reset(vsk, NULL);
744 
745 		virtio_transport_do_close(vsk, false);
746 	}
747 
748 	vsk->close_work_scheduled = false;
749 
750 	release_sock(sk);
751 	sock_put(sk);
752 }
753 
754 /* User context, vsk->sk is locked */
755 static bool virtio_transport_close(struct vsock_sock *vsk)
756 {
757 	struct sock *sk = &vsk->sk;
758 
759 	if (!(sk->sk_state == TCP_ESTABLISHED ||
760 	      sk->sk_state == TCP_CLOSING))
761 		return true;
762 
763 	/* Already received SHUTDOWN from peer, reply with RST */
764 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
765 		(void)virtio_transport_reset(vsk, NULL);
766 		return true;
767 	}
768 
769 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
770 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
771 
772 	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
773 		virtio_transport_wait_close(sk, sk->sk_lingertime);
774 
775 	if (sock_flag(sk, SOCK_DONE)) {
776 		return true;
777 	}
778 
779 	sock_hold(sk);
780 	INIT_DELAYED_WORK(&vsk->close_work,
781 			  virtio_transport_close_timeout);
782 	vsk->close_work_scheduled = true;
783 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
784 	return false;
785 }
786 
787 void virtio_transport_release(struct vsock_sock *vsk)
788 {
789 	struct sock *sk = &vsk->sk;
790 	bool remove_sock = true;
791 
792 	lock_sock(sk);
793 	if (sk->sk_type == SOCK_STREAM)
794 		remove_sock = virtio_transport_close(vsk);
795 	release_sock(sk);
796 
797 	if (remove_sock)
798 		vsock_remove_sock(vsk);
799 }
800 EXPORT_SYMBOL_GPL(virtio_transport_release);
801 
802 static int
803 virtio_transport_recv_connecting(struct sock *sk,
804 				 struct virtio_vsock_pkt *pkt)
805 {
806 	struct vsock_sock *vsk = vsock_sk(sk);
807 	int err;
808 	int skerr;
809 
810 	switch (le16_to_cpu(pkt->hdr.op)) {
811 	case VIRTIO_VSOCK_OP_RESPONSE:
812 		sk->sk_state = TCP_ESTABLISHED;
813 		sk->sk_socket->state = SS_CONNECTED;
814 		vsock_insert_connected(vsk);
815 		sk->sk_state_change(sk);
816 		break;
817 	case VIRTIO_VSOCK_OP_INVALID:
818 		break;
819 	case VIRTIO_VSOCK_OP_RST:
820 		skerr = ECONNRESET;
821 		err = 0;
822 		goto destroy;
823 	default:
824 		skerr = EPROTO;
825 		err = -EINVAL;
826 		goto destroy;
827 	}
828 	return 0;
829 
830 destroy:
831 	virtio_transport_reset(vsk, pkt);
832 	sk->sk_state = TCP_CLOSE;
833 	sk->sk_err = skerr;
834 	sk->sk_error_report(sk);
835 	return err;
836 }
837 
838 static int
839 virtio_transport_recv_connected(struct sock *sk,
840 				struct virtio_vsock_pkt *pkt)
841 {
842 	struct vsock_sock *vsk = vsock_sk(sk);
843 	struct virtio_vsock_sock *vvs = vsk->trans;
844 	int err = 0;
845 
846 	switch (le16_to_cpu(pkt->hdr.op)) {
847 	case VIRTIO_VSOCK_OP_RW:
848 		pkt->len = le32_to_cpu(pkt->hdr.len);
849 		pkt->off = 0;
850 
851 		spin_lock_bh(&vvs->rx_lock);
852 		virtio_transport_inc_rx_pkt(vvs, pkt);
853 		list_add_tail(&pkt->list, &vvs->rx_queue);
854 		spin_unlock_bh(&vvs->rx_lock);
855 
856 		sk->sk_data_ready(sk);
857 		return err;
858 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
859 		sk->sk_write_space(sk);
860 		break;
861 	case VIRTIO_VSOCK_OP_SHUTDOWN:
862 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
863 			vsk->peer_shutdown |= RCV_SHUTDOWN;
864 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
865 			vsk->peer_shutdown |= SEND_SHUTDOWN;
866 		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
867 		    vsock_stream_has_data(vsk) <= 0)
868 			sk->sk_state = TCP_CLOSING;
869 		if (le32_to_cpu(pkt->hdr.flags))
870 			sk->sk_state_change(sk);
871 		break;
872 	case VIRTIO_VSOCK_OP_RST:
873 		virtio_transport_do_close(vsk, true);
874 		break;
875 	default:
876 		err = -EINVAL;
877 		break;
878 	}
879 
880 	virtio_transport_free_pkt(pkt);
881 	return err;
882 }
883 
884 static void
885 virtio_transport_recv_disconnecting(struct sock *sk,
886 				    struct virtio_vsock_pkt *pkt)
887 {
888 	struct vsock_sock *vsk = vsock_sk(sk);
889 
890 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
891 		virtio_transport_do_close(vsk, true);
892 }
893 
894 static int
895 virtio_transport_send_response(struct vsock_sock *vsk,
896 			       struct virtio_vsock_pkt *pkt)
897 {
898 	struct virtio_vsock_pkt_info info = {
899 		.op = VIRTIO_VSOCK_OP_RESPONSE,
900 		.type = VIRTIO_VSOCK_TYPE_STREAM,
901 		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
902 		.remote_port = le32_to_cpu(pkt->hdr.src_port),
903 		.reply = true,
904 		.vsk = vsk,
905 	};
906 
907 	return virtio_transport_send_pkt_info(vsk, &info);
908 }
909 
910 /* Handle server socket */
911 static int
912 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
913 {
914 	struct vsock_sock *vsk = vsock_sk(sk);
915 	struct vsock_sock *vchild;
916 	struct sock *child;
917 
918 	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
919 		virtio_transport_reset(vsk, pkt);
920 		return -EINVAL;
921 	}
922 
923 	if (sk_acceptq_is_full(sk)) {
924 		virtio_transport_reset(vsk, pkt);
925 		return -ENOMEM;
926 	}
927 
928 	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
929 			       sk->sk_type, 0);
930 	if (!child) {
931 		virtio_transport_reset(vsk, pkt);
932 		return -ENOMEM;
933 	}
934 
935 	sk->sk_ack_backlog++;
936 
937 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
938 
939 	child->sk_state = TCP_ESTABLISHED;
940 
941 	vchild = vsock_sk(child);
942 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
943 			le32_to_cpu(pkt->hdr.dst_port));
944 	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
945 			le32_to_cpu(pkt->hdr.src_port));
946 
947 	vsock_insert_connected(vchild);
948 	vsock_enqueue_accept(sk, child);
949 	virtio_transport_send_response(vchild, pkt);
950 
951 	release_sock(child);
952 
953 	sk->sk_data_ready(sk);
954 	return 0;
955 }
956 
957 static bool virtio_transport_space_update(struct sock *sk,
958 					  struct virtio_vsock_pkt *pkt)
959 {
960 	struct vsock_sock *vsk = vsock_sk(sk);
961 	struct virtio_vsock_sock *vvs = vsk->trans;
962 	bool space_available;
963 
964 	/* buf_alloc and fwd_cnt is always included in the hdr */
965 	spin_lock_bh(&vvs->tx_lock);
966 	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
967 	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
968 	space_available = virtio_transport_has_space(vsk);
969 	spin_unlock_bh(&vvs->tx_lock);
970 	return space_available;
971 }
972 
973 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
974  * lock.
975  */
976 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
977 {
978 	struct sockaddr_vm src, dst;
979 	struct vsock_sock *vsk;
980 	struct sock *sk;
981 	bool space_available;
982 
983 	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
984 			le32_to_cpu(pkt->hdr.src_port));
985 	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
986 			le32_to_cpu(pkt->hdr.dst_port));
987 
988 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
989 					dst.svm_cid, dst.svm_port,
990 					le32_to_cpu(pkt->hdr.len),
991 					le16_to_cpu(pkt->hdr.type),
992 					le16_to_cpu(pkt->hdr.op),
993 					le32_to_cpu(pkt->hdr.flags),
994 					le32_to_cpu(pkt->hdr.buf_alloc),
995 					le32_to_cpu(pkt->hdr.fwd_cnt));
996 
997 	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
998 		(void)virtio_transport_reset_no_sock(pkt);
999 		goto free_pkt;
1000 	}
1001 
1002 	/* The socket must be in connected or bound table
1003 	 * otherwise send reset back
1004 	 */
1005 	sk = vsock_find_connected_socket(&src, &dst);
1006 	if (!sk) {
1007 		sk = vsock_find_bound_socket(&dst);
1008 		if (!sk) {
1009 			(void)virtio_transport_reset_no_sock(pkt);
1010 			goto free_pkt;
1011 		}
1012 	}
1013 
1014 	vsk = vsock_sk(sk);
1015 
1016 	space_available = virtio_transport_space_update(sk, pkt);
1017 
1018 	lock_sock(sk);
1019 
1020 	/* Update CID in case it has changed after a transport reset event */
1021 	vsk->local_addr.svm_cid = dst.svm_cid;
1022 
1023 	if (space_available)
1024 		sk->sk_write_space(sk);
1025 
1026 	switch (sk->sk_state) {
1027 	case TCP_LISTEN:
1028 		virtio_transport_recv_listen(sk, pkt);
1029 		virtio_transport_free_pkt(pkt);
1030 		break;
1031 	case TCP_SYN_SENT:
1032 		virtio_transport_recv_connecting(sk, pkt);
1033 		virtio_transport_free_pkt(pkt);
1034 		break;
1035 	case TCP_ESTABLISHED:
1036 		virtio_transport_recv_connected(sk, pkt);
1037 		break;
1038 	case TCP_CLOSING:
1039 		virtio_transport_recv_disconnecting(sk, pkt);
1040 		virtio_transport_free_pkt(pkt);
1041 		break;
1042 	default:
1043 		virtio_transport_free_pkt(pkt);
1044 		break;
1045 	}
1046 	release_sock(sk);
1047 
1048 	/* Release refcnt obtained when we fetched this socket out of the
1049 	 * bound or connected list.
1050 	 */
1051 	sock_put(sk);
1052 	return;
1053 
1054 free_pkt:
1055 	virtio_transport_free_pkt(pkt);
1056 }
1057 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1058 
1059 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1060 {
1061 	kfree(pkt->buf);
1062 	kfree(pkt);
1063 }
1064 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1065 
1066 MODULE_LICENSE("GPL v2");
1067 MODULE_AUTHOR("Asias He");
1068 MODULE_DESCRIPTION("common code for virtio vsock");
1069