1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/sched/signal.h>
13 #include <linux/ctype.h>
14 #include <linux/list.h>
15 #include <linux/virtio.h>
16 #include <linux/virtio_ids.h>
17 #include <linux/virtio_config.h>
18 #include <linux/virtio_vsock.h>
19 #include <uapi/linux/vsockmon.h>
20 
21 #include <net/sock.h>
22 #include <net/af_vsock.h>
23 
24 #define CREATE_TRACE_POINTS
25 #include <trace/events/vsock_virtio_transport_common.h>
26 
27 /* How long to wait for graceful shutdown of a connection */
28 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
29 
30 static const struct virtio_transport *virtio_transport_get_ops(void)
31 {
32 	const struct vsock_transport *t = vsock_core_get_transport();
33 
34 	return container_of(t, struct virtio_transport, transport);
35 }
36 
37 static struct virtio_vsock_pkt *
38 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
39 			   size_t len,
40 			   u32 src_cid,
41 			   u32 src_port,
42 			   u32 dst_cid,
43 			   u32 dst_port)
44 {
45 	struct virtio_vsock_pkt *pkt;
46 	int err;
47 
48 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
49 	if (!pkt)
50 		return NULL;
51 
52 	pkt->hdr.type		= cpu_to_le16(info->type);
53 	pkt->hdr.op		= cpu_to_le16(info->op);
54 	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
55 	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
56 	pkt->hdr.src_port	= cpu_to_le32(src_port);
57 	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
58 	pkt->hdr.flags		= cpu_to_le32(info->flags);
59 	pkt->len		= len;
60 	pkt->hdr.len		= cpu_to_le32(len);
61 	pkt->reply		= info->reply;
62 	pkt->vsk		= info->vsk;
63 
64 	if (info->msg && len > 0) {
65 		pkt->buf = kmalloc(len, GFP_KERNEL);
66 		if (!pkt->buf)
67 			goto out_pkt;
68 		err = memcpy_from_msg(pkt->buf, info->msg, len);
69 		if (err)
70 			goto out;
71 	}
72 
73 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
74 					 dst_cid, dst_port,
75 					 len,
76 					 info->type,
77 					 info->op,
78 					 info->flags);
79 
80 	return pkt;
81 
82 out:
83 	kfree(pkt->buf);
84 out_pkt:
85 	kfree(pkt);
86 	return NULL;
87 }
88 
89 /* Packet capture */
90 static struct sk_buff *virtio_transport_build_skb(void *opaque)
91 {
92 	struct virtio_vsock_pkt *pkt = opaque;
93 	struct af_vsockmon_hdr *hdr;
94 	struct sk_buff *skb;
95 
96 	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
97 			GFP_ATOMIC);
98 	if (!skb)
99 		return NULL;
100 
101 	hdr = skb_put(skb, sizeof(*hdr));
102 
103 	/* pkt->hdr is little-endian so no need to byteswap here */
104 	hdr->src_cid = pkt->hdr.src_cid;
105 	hdr->src_port = pkt->hdr.src_port;
106 	hdr->dst_cid = pkt->hdr.dst_cid;
107 	hdr->dst_port = pkt->hdr.dst_port;
108 
109 	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
110 	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
111 	memset(hdr->reserved, 0, sizeof(hdr->reserved));
112 
113 	switch (le16_to_cpu(pkt->hdr.op)) {
114 	case VIRTIO_VSOCK_OP_REQUEST:
115 	case VIRTIO_VSOCK_OP_RESPONSE:
116 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
117 		break;
118 	case VIRTIO_VSOCK_OP_RST:
119 	case VIRTIO_VSOCK_OP_SHUTDOWN:
120 		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
121 		break;
122 	case VIRTIO_VSOCK_OP_RW:
123 		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
124 		break;
125 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
126 	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
127 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
128 		break;
129 	default:
130 		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
131 		break;
132 	}
133 
134 	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
135 
136 	if (pkt->len) {
137 		skb_put_data(skb, pkt->buf, pkt->len);
138 	}
139 
140 	return skb;
141 }
142 
143 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
144 {
145 	vsock_deliver_tap(virtio_transport_build_skb, pkt);
146 }
147 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
148 
149 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
150 					  struct virtio_vsock_pkt_info *info)
151 {
152 	u32 src_cid, src_port, dst_cid, dst_port;
153 	struct virtio_vsock_sock *vvs;
154 	struct virtio_vsock_pkt *pkt;
155 	u32 pkt_len = info->pkt_len;
156 
157 	src_cid = vm_sockets_get_local_cid();
158 	src_port = vsk->local_addr.svm_port;
159 	if (!info->remote_cid) {
160 		dst_cid	= vsk->remote_addr.svm_cid;
161 		dst_port = vsk->remote_addr.svm_port;
162 	} else {
163 		dst_cid = info->remote_cid;
164 		dst_port = info->remote_port;
165 	}
166 
167 	vvs = vsk->trans;
168 
169 	/* we can send less than pkt_len bytes */
170 	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
171 		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
172 
173 	/* virtio_transport_get_credit might return less than pkt_len credit */
174 	pkt_len = virtio_transport_get_credit(vvs, pkt_len);
175 
176 	/* Do not send zero length OP_RW pkt */
177 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
178 		return pkt_len;
179 
180 	pkt = virtio_transport_alloc_pkt(info, pkt_len,
181 					 src_cid, src_port,
182 					 dst_cid, dst_port);
183 	if (!pkt) {
184 		virtio_transport_put_credit(vvs, pkt_len);
185 		return -ENOMEM;
186 	}
187 
188 	virtio_transport_inc_tx_pkt(vvs, pkt);
189 
190 	return virtio_transport_get_ops()->send_pkt(pkt);
191 }
192 
193 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
194 					struct virtio_vsock_pkt *pkt)
195 {
196 	vvs->rx_bytes += pkt->len;
197 }
198 
199 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
200 					struct virtio_vsock_pkt *pkt)
201 {
202 	vvs->rx_bytes -= pkt->len;
203 	vvs->fwd_cnt += pkt->len;
204 }
205 
206 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
207 {
208 	spin_lock_bh(&vvs->tx_lock);
209 	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
210 	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
211 	spin_unlock_bh(&vvs->tx_lock);
212 }
213 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
214 
215 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
216 {
217 	u32 ret;
218 
219 	spin_lock_bh(&vvs->tx_lock);
220 	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
221 	if (ret > credit)
222 		ret = credit;
223 	vvs->tx_cnt += ret;
224 	spin_unlock_bh(&vvs->tx_lock);
225 
226 	return ret;
227 }
228 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
229 
230 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
231 {
232 	spin_lock_bh(&vvs->tx_lock);
233 	vvs->tx_cnt -= credit;
234 	spin_unlock_bh(&vvs->tx_lock);
235 }
236 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
237 
238 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
239 					       int type,
240 					       struct virtio_vsock_hdr *hdr)
241 {
242 	struct virtio_vsock_pkt_info info = {
243 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
244 		.type = type,
245 		.vsk = vsk,
246 	};
247 
248 	return virtio_transport_send_pkt_info(vsk, &info);
249 }
250 
251 static ssize_t
252 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
253 				   struct msghdr *msg,
254 				   size_t len)
255 {
256 	struct virtio_vsock_sock *vvs = vsk->trans;
257 	struct virtio_vsock_pkt *pkt;
258 	size_t bytes, total = 0;
259 	int err = -EFAULT;
260 
261 	spin_lock_bh(&vvs->rx_lock);
262 	while (total < len && !list_empty(&vvs->rx_queue)) {
263 		pkt = list_first_entry(&vvs->rx_queue,
264 				       struct virtio_vsock_pkt, list);
265 
266 		bytes = len - total;
267 		if (bytes > pkt->len - pkt->off)
268 			bytes = pkt->len - pkt->off;
269 
270 		/* sk_lock is held by caller so no one else can dequeue.
271 		 * Unlock rx_lock since memcpy_to_msg() may sleep.
272 		 */
273 		spin_unlock_bh(&vvs->rx_lock);
274 
275 		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
276 		if (err)
277 			goto out;
278 
279 		spin_lock_bh(&vvs->rx_lock);
280 
281 		total += bytes;
282 		pkt->off += bytes;
283 		if (pkt->off == pkt->len) {
284 			virtio_transport_dec_rx_pkt(vvs, pkt);
285 			list_del(&pkt->list);
286 			virtio_transport_free_pkt(pkt);
287 		}
288 	}
289 	spin_unlock_bh(&vvs->rx_lock);
290 
291 	/* Send a credit pkt to peer */
292 	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
293 					    NULL);
294 
295 	return total;
296 
297 out:
298 	if (total)
299 		err = total;
300 	return err;
301 }
302 
303 ssize_t
304 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
305 				struct msghdr *msg,
306 				size_t len, int flags)
307 {
308 	if (flags & MSG_PEEK)
309 		return -EOPNOTSUPP;
310 
311 	return virtio_transport_stream_do_dequeue(vsk, msg, len);
312 }
313 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
314 
315 int
316 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
317 			       struct msghdr *msg,
318 			       size_t len, int flags)
319 {
320 	return -EOPNOTSUPP;
321 }
322 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
323 
324 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
325 {
326 	struct virtio_vsock_sock *vvs = vsk->trans;
327 	s64 bytes;
328 
329 	spin_lock_bh(&vvs->rx_lock);
330 	bytes = vvs->rx_bytes;
331 	spin_unlock_bh(&vvs->rx_lock);
332 
333 	return bytes;
334 }
335 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
336 
337 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
338 {
339 	struct virtio_vsock_sock *vvs = vsk->trans;
340 	s64 bytes;
341 
342 	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
343 	if (bytes < 0)
344 		bytes = 0;
345 
346 	return bytes;
347 }
348 
349 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
350 {
351 	struct virtio_vsock_sock *vvs = vsk->trans;
352 	s64 bytes;
353 
354 	spin_lock_bh(&vvs->tx_lock);
355 	bytes = virtio_transport_has_space(vsk);
356 	spin_unlock_bh(&vvs->tx_lock);
357 
358 	return bytes;
359 }
360 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
361 
362 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
363 				    struct vsock_sock *psk)
364 {
365 	struct virtio_vsock_sock *vvs;
366 
367 	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
368 	if (!vvs)
369 		return -ENOMEM;
370 
371 	vsk->trans = vvs;
372 	vvs->vsk = vsk;
373 	if (psk) {
374 		struct virtio_vsock_sock *ptrans = psk->trans;
375 
376 		vvs->buf_size	= ptrans->buf_size;
377 		vvs->buf_size_min = ptrans->buf_size_min;
378 		vvs->buf_size_max = ptrans->buf_size_max;
379 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
380 	} else {
381 		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
382 		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
383 		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
384 	}
385 
386 	vvs->buf_alloc = vvs->buf_size;
387 
388 	spin_lock_init(&vvs->rx_lock);
389 	spin_lock_init(&vvs->tx_lock);
390 	INIT_LIST_HEAD(&vvs->rx_queue);
391 
392 	return 0;
393 }
394 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
395 
396 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
397 {
398 	struct virtio_vsock_sock *vvs = vsk->trans;
399 
400 	return vvs->buf_size;
401 }
402 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
403 
404 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
405 {
406 	struct virtio_vsock_sock *vvs = vsk->trans;
407 
408 	return vvs->buf_size_min;
409 }
410 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
411 
412 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
413 {
414 	struct virtio_vsock_sock *vvs = vsk->trans;
415 
416 	return vvs->buf_size_max;
417 }
418 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
419 
420 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
421 {
422 	struct virtio_vsock_sock *vvs = vsk->trans;
423 
424 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
425 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
426 	if (val < vvs->buf_size_min)
427 		vvs->buf_size_min = val;
428 	if (val > vvs->buf_size_max)
429 		vvs->buf_size_max = val;
430 	vvs->buf_size = val;
431 	vvs->buf_alloc = val;
432 }
433 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
434 
435 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
436 {
437 	struct virtio_vsock_sock *vvs = vsk->trans;
438 
439 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
440 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
441 	if (val > vvs->buf_size)
442 		vvs->buf_size = val;
443 	vvs->buf_size_min = val;
444 }
445 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
446 
447 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
448 {
449 	struct virtio_vsock_sock *vvs = vsk->trans;
450 
451 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
452 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
453 	if (val < vvs->buf_size)
454 		vvs->buf_size = val;
455 	vvs->buf_size_max = val;
456 }
457 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
458 
459 int
460 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
461 				size_t target,
462 				bool *data_ready_now)
463 {
464 	if (vsock_stream_has_data(vsk))
465 		*data_ready_now = true;
466 	else
467 		*data_ready_now = false;
468 
469 	return 0;
470 }
471 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
472 
473 int
474 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
475 				 size_t target,
476 				 bool *space_avail_now)
477 {
478 	s64 free_space;
479 
480 	free_space = vsock_stream_has_space(vsk);
481 	if (free_space > 0)
482 		*space_avail_now = true;
483 	else if (free_space == 0)
484 		*space_avail_now = false;
485 
486 	return 0;
487 }
488 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
489 
490 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
491 	size_t target, struct vsock_transport_recv_notify_data *data)
492 {
493 	return 0;
494 }
495 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
496 
497 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
498 	size_t target, struct vsock_transport_recv_notify_data *data)
499 {
500 	return 0;
501 }
502 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
503 
504 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
505 	size_t target, struct vsock_transport_recv_notify_data *data)
506 {
507 	return 0;
508 }
509 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
510 
511 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
512 	size_t target, ssize_t copied, bool data_read,
513 	struct vsock_transport_recv_notify_data *data)
514 {
515 	return 0;
516 }
517 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
518 
519 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
520 	struct vsock_transport_send_notify_data *data)
521 {
522 	return 0;
523 }
524 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
525 
526 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
527 	struct vsock_transport_send_notify_data *data)
528 {
529 	return 0;
530 }
531 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
532 
533 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
534 	struct vsock_transport_send_notify_data *data)
535 {
536 	return 0;
537 }
538 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
539 
540 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
541 	ssize_t written, struct vsock_transport_send_notify_data *data)
542 {
543 	return 0;
544 }
545 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
546 
547 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
548 {
549 	struct virtio_vsock_sock *vvs = vsk->trans;
550 
551 	return vvs->buf_size;
552 }
553 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
554 
555 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
556 {
557 	return true;
558 }
559 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
560 
561 bool virtio_transport_stream_allow(u32 cid, u32 port)
562 {
563 	return true;
564 }
565 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
566 
567 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
568 				struct sockaddr_vm *addr)
569 {
570 	return -EOPNOTSUPP;
571 }
572 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
573 
574 bool virtio_transport_dgram_allow(u32 cid, u32 port)
575 {
576 	return false;
577 }
578 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
579 
580 int virtio_transport_connect(struct vsock_sock *vsk)
581 {
582 	struct virtio_vsock_pkt_info info = {
583 		.op = VIRTIO_VSOCK_OP_REQUEST,
584 		.type = VIRTIO_VSOCK_TYPE_STREAM,
585 		.vsk = vsk,
586 	};
587 
588 	return virtio_transport_send_pkt_info(vsk, &info);
589 }
590 EXPORT_SYMBOL_GPL(virtio_transport_connect);
591 
592 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
593 {
594 	struct virtio_vsock_pkt_info info = {
595 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
596 		.type = VIRTIO_VSOCK_TYPE_STREAM,
597 		.flags = (mode & RCV_SHUTDOWN ?
598 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
599 			 (mode & SEND_SHUTDOWN ?
600 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
601 		.vsk = vsk,
602 	};
603 
604 	return virtio_transport_send_pkt_info(vsk, &info);
605 }
606 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
607 
608 int
609 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
610 			       struct sockaddr_vm *remote_addr,
611 			       struct msghdr *msg,
612 			       size_t dgram_len)
613 {
614 	return -EOPNOTSUPP;
615 }
616 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
617 
618 ssize_t
619 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
620 				struct msghdr *msg,
621 				size_t len)
622 {
623 	struct virtio_vsock_pkt_info info = {
624 		.op = VIRTIO_VSOCK_OP_RW,
625 		.type = VIRTIO_VSOCK_TYPE_STREAM,
626 		.msg = msg,
627 		.pkt_len = len,
628 		.vsk = vsk,
629 	};
630 
631 	return virtio_transport_send_pkt_info(vsk, &info);
632 }
633 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
634 
635 void virtio_transport_destruct(struct vsock_sock *vsk)
636 {
637 	struct virtio_vsock_sock *vvs = vsk->trans;
638 
639 	kfree(vvs);
640 }
641 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
642 
643 static int virtio_transport_reset(struct vsock_sock *vsk,
644 				  struct virtio_vsock_pkt *pkt)
645 {
646 	struct virtio_vsock_pkt_info info = {
647 		.op = VIRTIO_VSOCK_OP_RST,
648 		.type = VIRTIO_VSOCK_TYPE_STREAM,
649 		.reply = !!pkt,
650 		.vsk = vsk,
651 	};
652 
653 	/* Send RST only if the original pkt is not a RST pkt */
654 	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
655 		return 0;
656 
657 	return virtio_transport_send_pkt_info(vsk, &info);
658 }
659 
660 /* Normally packets are associated with a socket.  There may be no socket if an
661  * attempt was made to connect to a socket that does not exist.
662  */
663 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
664 {
665 	struct virtio_vsock_pkt_info info = {
666 		.op = VIRTIO_VSOCK_OP_RST,
667 		.type = le16_to_cpu(pkt->hdr.type),
668 		.reply = true,
669 	};
670 
671 	/* Send RST only if the original pkt is not a RST pkt */
672 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
673 		return 0;
674 
675 	pkt = virtio_transport_alloc_pkt(&info, 0,
676 					 le64_to_cpu(pkt->hdr.dst_cid),
677 					 le32_to_cpu(pkt->hdr.dst_port),
678 					 le64_to_cpu(pkt->hdr.src_cid),
679 					 le32_to_cpu(pkt->hdr.src_port));
680 	if (!pkt)
681 		return -ENOMEM;
682 
683 	return virtio_transport_get_ops()->send_pkt(pkt);
684 }
685 
686 static void virtio_transport_wait_close(struct sock *sk, long timeout)
687 {
688 	if (timeout) {
689 		DEFINE_WAIT_FUNC(wait, woken_wake_function);
690 
691 		add_wait_queue(sk_sleep(sk), &wait);
692 
693 		do {
694 			if (sk_wait_event(sk, &timeout,
695 					  sock_flag(sk, SOCK_DONE), &wait))
696 				break;
697 		} while (!signal_pending(current) && timeout);
698 
699 		remove_wait_queue(sk_sleep(sk), &wait);
700 	}
701 }
702 
703 static void virtio_transport_do_close(struct vsock_sock *vsk,
704 				      bool cancel_timeout)
705 {
706 	struct sock *sk = sk_vsock(vsk);
707 
708 	sock_set_flag(sk, SOCK_DONE);
709 	vsk->peer_shutdown = SHUTDOWN_MASK;
710 	if (vsock_stream_has_data(vsk) <= 0)
711 		sk->sk_state = TCP_CLOSING;
712 	sk->sk_state_change(sk);
713 
714 	if (vsk->close_work_scheduled &&
715 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
716 		vsk->close_work_scheduled = false;
717 
718 		vsock_remove_sock(vsk);
719 
720 		/* Release refcnt obtained when we scheduled the timeout */
721 		sock_put(sk);
722 	}
723 }
724 
725 static void virtio_transport_close_timeout(struct work_struct *work)
726 {
727 	struct vsock_sock *vsk =
728 		container_of(work, struct vsock_sock, close_work.work);
729 	struct sock *sk = sk_vsock(vsk);
730 
731 	sock_hold(sk);
732 	lock_sock(sk);
733 
734 	if (!sock_flag(sk, SOCK_DONE)) {
735 		(void)virtio_transport_reset(vsk, NULL);
736 
737 		virtio_transport_do_close(vsk, false);
738 	}
739 
740 	vsk->close_work_scheduled = false;
741 
742 	release_sock(sk);
743 	sock_put(sk);
744 }
745 
746 /* User context, vsk->sk is locked */
747 static bool virtio_transport_close(struct vsock_sock *vsk)
748 {
749 	struct sock *sk = &vsk->sk;
750 
751 	if (!(sk->sk_state == TCP_ESTABLISHED ||
752 	      sk->sk_state == TCP_CLOSING))
753 		return true;
754 
755 	/* Already received SHUTDOWN from peer, reply with RST */
756 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
757 		(void)virtio_transport_reset(vsk, NULL);
758 		return true;
759 	}
760 
761 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
762 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
763 
764 	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
765 		virtio_transport_wait_close(sk, sk->sk_lingertime);
766 
767 	if (sock_flag(sk, SOCK_DONE)) {
768 		return true;
769 	}
770 
771 	sock_hold(sk);
772 	INIT_DELAYED_WORK(&vsk->close_work,
773 			  virtio_transport_close_timeout);
774 	vsk->close_work_scheduled = true;
775 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
776 	return false;
777 }
778 
779 void virtio_transport_release(struct vsock_sock *vsk)
780 {
781 	struct sock *sk = &vsk->sk;
782 	bool remove_sock = true;
783 
784 	lock_sock(sk);
785 	if (sk->sk_type == SOCK_STREAM)
786 		remove_sock = virtio_transport_close(vsk);
787 	release_sock(sk);
788 
789 	if (remove_sock)
790 		vsock_remove_sock(vsk);
791 }
792 EXPORT_SYMBOL_GPL(virtio_transport_release);
793 
794 static int
795 virtio_transport_recv_connecting(struct sock *sk,
796 				 struct virtio_vsock_pkt *pkt)
797 {
798 	struct vsock_sock *vsk = vsock_sk(sk);
799 	int err;
800 	int skerr;
801 
802 	switch (le16_to_cpu(pkt->hdr.op)) {
803 	case VIRTIO_VSOCK_OP_RESPONSE:
804 		sk->sk_state = TCP_ESTABLISHED;
805 		sk->sk_socket->state = SS_CONNECTED;
806 		vsock_insert_connected(vsk);
807 		sk->sk_state_change(sk);
808 		break;
809 	case VIRTIO_VSOCK_OP_INVALID:
810 		break;
811 	case VIRTIO_VSOCK_OP_RST:
812 		skerr = ECONNRESET;
813 		err = 0;
814 		goto destroy;
815 	default:
816 		skerr = EPROTO;
817 		err = -EINVAL;
818 		goto destroy;
819 	}
820 	return 0;
821 
822 destroy:
823 	virtio_transport_reset(vsk, pkt);
824 	sk->sk_state = TCP_CLOSE;
825 	sk->sk_err = skerr;
826 	sk->sk_error_report(sk);
827 	return err;
828 }
829 
830 static int
831 virtio_transport_recv_connected(struct sock *sk,
832 				struct virtio_vsock_pkt *pkt)
833 {
834 	struct vsock_sock *vsk = vsock_sk(sk);
835 	struct virtio_vsock_sock *vvs = vsk->trans;
836 	int err = 0;
837 
838 	switch (le16_to_cpu(pkt->hdr.op)) {
839 	case VIRTIO_VSOCK_OP_RW:
840 		pkt->len = le32_to_cpu(pkt->hdr.len);
841 		pkt->off = 0;
842 
843 		spin_lock_bh(&vvs->rx_lock);
844 		virtio_transport_inc_rx_pkt(vvs, pkt);
845 		list_add_tail(&pkt->list, &vvs->rx_queue);
846 		spin_unlock_bh(&vvs->rx_lock);
847 
848 		sk->sk_data_ready(sk);
849 		return err;
850 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
851 		sk->sk_write_space(sk);
852 		break;
853 	case VIRTIO_VSOCK_OP_SHUTDOWN:
854 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
855 			vsk->peer_shutdown |= RCV_SHUTDOWN;
856 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
857 			vsk->peer_shutdown |= SEND_SHUTDOWN;
858 		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
859 		    vsock_stream_has_data(vsk) <= 0)
860 			sk->sk_state = TCP_CLOSING;
861 		if (le32_to_cpu(pkt->hdr.flags))
862 			sk->sk_state_change(sk);
863 		break;
864 	case VIRTIO_VSOCK_OP_RST:
865 		virtio_transport_do_close(vsk, true);
866 		break;
867 	default:
868 		err = -EINVAL;
869 		break;
870 	}
871 
872 	virtio_transport_free_pkt(pkt);
873 	return err;
874 }
875 
876 static void
877 virtio_transport_recv_disconnecting(struct sock *sk,
878 				    struct virtio_vsock_pkt *pkt)
879 {
880 	struct vsock_sock *vsk = vsock_sk(sk);
881 
882 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
883 		virtio_transport_do_close(vsk, true);
884 }
885 
886 static int
887 virtio_transport_send_response(struct vsock_sock *vsk,
888 			       struct virtio_vsock_pkt *pkt)
889 {
890 	struct virtio_vsock_pkt_info info = {
891 		.op = VIRTIO_VSOCK_OP_RESPONSE,
892 		.type = VIRTIO_VSOCK_TYPE_STREAM,
893 		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
894 		.remote_port = le32_to_cpu(pkt->hdr.src_port),
895 		.reply = true,
896 		.vsk = vsk,
897 	};
898 
899 	return virtio_transport_send_pkt_info(vsk, &info);
900 }
901 
902 /* Handle server socket */
903 static int
904 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
905 {
906 	struct vsock_sock *vsk = vsock_sk(sk);
907 	struct vsock_sock *vchild;
908 	struct sock *child;
909 
910 	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
911 		virtio_transport_reset(vsk, pkt);
912 		return -EINVAL;
913 	}
914 
915 	if (sk_acceptq_is_full(sk)) {
916 		virtio_transport_reset(vsk, pkt);
917 		return -ENOMEM;
918 	}
919 
920 	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
921 			       sk->sk_type, 0);
922 	if (!child) {
923 		virtio_transport_reset(vsk, pkt);
924 		return -ENOMEM;
925 	}
926 
927 	sk->sk_ack_backlog++;
928 
929 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
930 
931 	child->sk_state = TCP_ESTABLISHED;
932 
933 	vchild = vsock_sk(child);
934 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
935 			le32_to_cpu(pkt->hdr.dst_port));
936 	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
937 			le32_to_cpu(pkt->hdr.src_port));
938 
939 	vsock_insert_connected(vchild);
940 	vsock_enqueue_accept(sk, child);
941 	virtio_transport_send_response(vchild, pkt);
942 
943 	release_sock(child);
944 
945 	sk->sk_data_ready(sk);
946 	return 0;
947 }
948 
949 static bool virtio_transport_space_update(struct sock *sk,
950 					  struct virtio_vsock_pkt *pkt)
951 {
952 	struct vsock_sock *vsk = vsock_sk(sk);
953 	struct virtio_vsock_sock *vvs = vsk->trans;
954 	bool space_available;
955 
956 	/* buf_alloc and fwd_cnt is always included in the hdr */
957 	spin_lock_bh(&vvs->tx_lock);
958 	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
959 	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
960 	space_available = virtio_transport_has_space(vsk);
961 	spin_unlock_bh(&vvs->tx_lock);
962 	return space_available;
963 }
964 
965 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
966  * lock.
967  */
968 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
969 {
970 	struct sockaddr_vm src, dst;
971 	struct vsock_sock *vsk;
972 	struct sock *sk;
973 	bool space_available;
974 
975 	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
976 			le32_to_cpu(pkt->hdr.src_port));
977 	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
978 			le32_to_cpu(pkt->hdr.dst_port));
979 
980 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
981 					dst.svm_cid, dst.svm_port,
982 					le32_to_cpu(pkt->hdr.len),
983 					le16_to_cpu(pkt->hdr.type),
984 					le16_to_cpu(pkt->hdr.op),
985 					le32_to_cpu(pkt->hdr.flags),
986 					le32_to_cpu(pkt->hdr.buf_alloc),
987 					le32_to_cpu(pkt->hdr.fwd_cnt));
988 
989 	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
990 		(void)virtio_transport_reset_no_sock(pkt);
991 		goto free_pkt;
992 	}
993 
994 	/* The socket must be in connected or bound table
995 	 * otherwise send reset back
996 	 */
997 	sk = vsock_find_connected_socket(&src, &dst);
998 	if (!sk) {
999 		sk = vsock_find_bound_socket(&dst);
1000 		if (!sk) {
1001 			(void)virtio_transport_reset_no_sock(pkt);
1002 			goto free_pkt;
1003 		}
1004 	}
1005 
1006 	vsk = vsock_sk(sk);
1007 
1008 	space_available = virtio_transport_space_update(sk, pkt);
1009 
1010 	lock_sock(sk);
1011 
1012 	/* Update CID in case it has changed after a transport reset event */
1013 	vsk->local_addr.svm_cid = dst.svm_cid;
1014 
1015 	if (space_available)
1016 		sk->sk_write_space(sk);
1017 
1018 	switch (sk->sk_state) {
1019 	case TCP_LISTEN:
1020 		virtio_transport_recv_listen(sk, pkt);
1021 		virtio_transport_free_pkt(pkt);
1022 		break;
1023 	case TCP_SYN_SENT:
1024 		virtio_transport_recv_connecting(sk, pkt);
1025 		virtio_transport_free_pkt(pkt);
1026 		break;
1027 	case TCP_ESTABLISHED:
1028 		virtio_transport_recv_connected(sk, pkt);
1029 		break;
1030 	case TCP_CLOSING:
1031 		virtio_transport_recv_disconnecting(sk, pkt);
1032 		virtio_transport_free_pkt(pkt);
1033 		break;
1034 	default:
1035 		virtio_transport_free_pkt(pkt);
1036 		break;
1037 	}
1038 	release_sock(sk);
1039 
1040 	/* Release refcnt obtained when we fetched this socket out of the
1041 	 * bound or connected list.
1042 	 */
1043 	sock_put(sk);
1044 	return;
1045 
1046 free_pkt:
1047 	virtio_transport_free_pkt(pkt);
1048 }
1049 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1050 
1051 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1052 {
1053 	kfree(pkt->buf);
1054 	kfree(pkt);
1055 }
1056 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1057 
1058 MODULE_LICENSE("GPL v2");
1059 MODULE_AUTHOR("Asias He");
1060 MODULE_DESCRIPTION("common code for virtio vsock");
1061