1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18 
19 #include <net/sock.h>
20 #include <net/af_vsock.h>
21 
22 #define CREATE_TRACE_POINTS
23 #include <trace/events/vsock_virtio_transport_common.h>
24 
25 /* How long to wait for graceful shutdown of a connection */
26 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
27 
28 static const struct virtio_transport *virtio_transport_get_ops(void)
29 {
30 	const struct vsock_transport *t = vsock_core_get_transport();
31 
32 	return container_of(t, struct virtio_transport, transport);
33 }
34 
35 struct virtio_vsock_pkt *
36 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
37 			   size_t len,
38 			   u32 src_cid,
39 			   u32 src_port,
40 			   u32 dst_cid,
41 			   u32 dst_port)
42 {
43 	struct virtio_vsock_pkt *pkt;
44 	int err;
45 
46 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
47 	if (!pkt)
48 		return NULL;
49 
50 	pkt->hdr.type		= cpu_to_le16(info->type);
51 	pkt->hdr.op		= cpu_to_le16(info->op);
52 	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
53 	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
54 	pkt->hdr.src_port	= cpu_to_le32(src_port);
55 	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
56 	pkt->hdr.flags		= cpu_to_le32(info->flags);
57 	pkt->len		= len;
58 	pkt->hdr.len		= cpu_to_le32(len);
59 	pkt->reply		= info->reply;
60 
61 	if (info->msg && len > 0) {
62 		pkt->buf = kmalloc(len, GFP_KERNEL);
63 		if (!pkt->buf)
64 			goto out_pkt;
65 		err = memcpy_from_msg(pkt->buf, info->msg, len);
66 		if (err)
67 			goto out;
68 	}
69 
70 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
71 					 dst_cid, dst_port,
72 					 len,
73 					 info->type,
74 					 info->op,
75 					 info->flags);
76 
77 	return pkt;
78 
79 out:
80 	kfree(pkt->buf);
81 out_pkt:
82 	kfree(pkt);
83 	return NULL;
84 }
85 EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
86 
87 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
88 					  struct virtio_vsock_pkt_info *info)
89 {
90 	u32 src_cid, src_port, dst_cid, dst_port;
91 	struct virtio_vsock_sock *vvs;
92 	struct virtio_vsock_pkt *pkt;
93 	u32 pkt_len = info->pkt_len;
94 
95 	src_cid = vm_sockets_get_local_cid();
96 	src_port = vsk->local_addr.svm_port;
97 	if (!info->remote_cid) {
98 		dst_cid	= vsk->remote_addr.svm_cid;
99 		dst_port = vsk->remote_addr.svm_port;
100 	} else {
101 		dst_cid = info->remote_cid;
102 		dst_port = info->remote_port;
103 	}
104 
105 	vvs = vsk->trans;
106 
107 	/* we can send less than pkt_len bytes */
108 	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
109 		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
110 
111 	/* virtio_transport_get_credit might return less than pkt_len credit */
112 	pkt_len = virtio_transport_get_credit(vvs, pkt_len);
113 
114 	/* Do not send zero length OP_RW pkt */
115 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
116 		return pkt_len;
117 
118 	pkt = virtio_transport_alloc_pkt(info, pkt_len,
119 					 src_cid, src_port,
120 					 dst_cid, dst_port);
121 	if (!pkt) {
122 		virtio_transport_put_credit(vvs, pkt_len);
123 		return -ENOMEM;
124 	}
125 
126 	virtio_transport_inc_tx_pkt(vvs, pkt);
127 
128 	return virtio_transport_get_ops()->send_pkt(pkt);
129 }
130 
131 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
132 					struct virtio_vsock_pkt *pkt)
133 {
134 	vvs->rx_bytes += pkt->len;
135 }
136 
137 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
138 					struct virtio_vsock_pkt *pkt)
139 {
140 	vvs->rx_bytes -= pkt->len;
141 	vvs->fwd_cnt += pkt->len;
142 }
143 
144 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
145 {
146 	spin_lock_bh(&vvs->tx_lock);
147 	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
148 	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
149 	spin_unlock_bh(&vvs->tx_lock);
150 }
151 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
152 
153 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
154 {
155 	u32 ret;
156 
157 	spin_lock_bh(&vvs->tx_lock);
158 	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
159 	if (ret > credit)
160 		ret = credit;
161 	vvs->tx_cnt += ret;
162 	spin_unlock_bh(&vvs->tx_lock);
163 
164 	return ret;
165 }
166 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
167 
168 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
169 {
170 	spin_lock_bh(&vvs->tx_lock);
171 	vvs->tx_cnt -= credit;
172 	spin_unlock_bh(&vvs->tx_lock);
173 }
174 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
175 
176 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
177 					       int type,
178 					       struct virtio_vsock_hdr *hdr)
179 {
180 	struct virtio_vsock_pkt_info info = {
181 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
182 		.type = type,
183 	};
184 
185 	return virtio_transport_send_pkt_info(vsk, &info);
186 }
187 
188 static ssize_t
189 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
190 				   struct msghdr *msg,
191 				   size_t len)
192 {
193 	struct virtio_vsock_sock *vvs = vsk->trans;
194 	struct virtio_vsock_pkt *pkt;
195 	size_t bytes, total = 0;
196 	int err = -EFAULT;
197 
198 	spin_lock_bh(&vvs->rx_lock);
199 	while (total < len && !list_empty(&vvs->rx_queue)) {
200 		pkt = list_first_entry(&vvs->rx_queue,
201 				       struct virtio_vsock_pkt, list);
202 
203 		bytes = len - total;
204 		if (bytes > pkt->len - pkt->off)
205 			bytes = pkt->len - pkt->off;
206 
207 		/* sk_lock is held by caller so no one else can dequeue.
208 		 * Unlock rx_lock since memcpy_to_msg() may sleep.
209 		 */
210 		spin_unlock_bh(&vvs->rx_lock);
211 
212 		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
213 		if (err)
214 			goto out;
215 
216 		spin_lock_bh(&vvs->rx_lock);
217 
218 		total += bytes;
219 		pkt->off += bytes;
220 		if (pkt->off == pkt->len) {
221 			virtio_transport_dec_rx_pkt(vvs, pkt);
222 			list_del(&pkt->list);
223 			virtio_transport_free_pkt(pkt);
224 		}
225 	}
226 	spin_unlock_bh(&vvs->rx_lock);
227 
228 	/* Send a credit pkt to peer */
229 	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
230 					    NULL);
231 
232 	return total;
233 
234 out:
235 	if (total)
236 		err = total;
237 	return err;
238 }
239 
240 ssize_t
241 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
242 				struct msghdr *msg,
243 				size_t len, int flags)
244 {
245 	if (flags & MSG_PEEK)
246 		return -EOPNOTSUPP;
247 
248 	return virtio_transport_stream_do_dequeue(vsk, msg, len);
249 }
250 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
251 
252 int
253 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
254 			       struct msghdr *msg,
255 			       size_t len, int flags)
256 {
257 	return -EOPNOTSUPP;
258 }
259 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
260 
261 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
262 {
263 	struct virtio_vsock_sock *vvs = vsk->trans;
264 	s64 bytes;
265 
266 	spin_lock_bh(&vvs->rx_lock);
267 	bytes = vvs->rx_bytes;
268 	spin_unlock_bh(&vvs->rx_lock);
269 
270 	return bytes;
271 }
272 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
273 
274 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
275 {
276 	struct virtio_vsock_sock *vvs = vsk->trans;
277 	s64 bytes;
278 
279 	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
280 	if (bytes < 0)
281 		bytes = 0;
282 
283 	return bytes;
284 }
285 
286 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
287 {
288 	struct virtio_vsock_sock *vvs = vsk->trans;
289 	s64 bytes;
290 
291 	spin_lock_bh(&vvs->tx_lock);
292 	bytes = virtio_transport_has_space(vsk);
293 	spin_unlock_bh(&vvs->tx_lock);
294 
295 	return bytes;
296 }
297 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
298 
299 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
300 				    struct vsock_sock *psk)
301 {
302 	struct virtio_vsock_sock *vvs;
303 
304 	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
305 	if (!vvs)
306 		return -ENOMEM;
307 
308 	vsk->trans = vvs;
309 	vvs->vsk = vsk;
310 	if (psk) {
311 		struct virtio_vsock_sock *ptrans = psk->trans;
312 
313 		vvs->buf_size	= ptrans->buf_size;
314 		vvs->buf_size_min = ptrans->buf_size_min;
315 		vvs->buf_size_max = ptrans->buf_size_max;
316 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
317 	} else {
318 		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
319 		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
320 		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
321 	}
322 
323 	vvs->buf_alloc = vvs->buf_size;
324 
325 	spin_lock_init(&vvs->rx_lock);
326 	spin_lock_init(&vvs->tx_lock);
327 	INIT_LIST_HEAD(&vvs->rx_queue);
328 
329 	return 0;
330 }
331 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
332 
333 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
334 {
335 	struct virtio_vsock_sock *vvs = vsk->trans;
336 
337 	return vvs->buf_size;
338 }
339 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
340 
341 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
342 {
343 	struct virtio_vsock_sock *vvs = vsk->trans;
344 
345 	return vvs->buf_size_min;
346 }
347 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
348 
349 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
350 {
351 	struct virtio_vsock_sock *vvs = vsk->trans;
352 
353 	return vvs->buf_size_max;
354 }
355 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
356 
357 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
358 {
359 	struct virtio_vsock_sock *vvs = vsk->trans;
360 
361 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
362 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
363 	if (val < vvs->buf_size_min)
364 		vvs->buf_size_min = val;
365 	if (val > vvs->buf_size_max)
366 		vvs->buf_size_max = val;
367 	vvs->buf_size = val;
368 	vvs->buf_alloc = val;
369 }
370 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
371 
372 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
373 {
374 	struct virtio_vsock_sock *vvs = vsk->trans;
375 
376 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
377 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
378 	if (val > vvs->buf_size)
379 		vvs->buf_size = val;
380 	vvs->buf_size_min = val;
381 }
382 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
383 
384 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
385 {
386 	struct virtio_vsock_sock *vvs = vsk->trans;
387 
388 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
389 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
390 	if (val < vvs->buf_size)
391 		vvs->buf_size = val;
392 	vvs->buf_size_max = val;
393 }
394 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
395 
396 int
397 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
398 				size_t target,
399 				bool *data_ready_now)
400 {
401 	if (vsock_stream_has_data(vsk))
402 		*data_ready_now = true;
403 	else
404 		*data_ready_now = false;
405 
406 	return 0;
407 }
408 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
409 
410 int
411 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
412 				 size_t target,
413 				 bool *space_avail_now)
414 {
415 	s64 free_space;
416 
417 	free_space = vsock_stream_has_space(vsk);
418 	if (free_space > 0)
419 		*space_avail_now = true;
420 	else if (free_space == 0)
421 		*space_avail_now = false;
422 
423 	return 0;
424 }
425 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
426 
427 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
428 	size_t target, struct vsock_transport_recv_notify_data *data)
429 {
430 	return 0;
431 }
432 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
433 
434 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
435 	size_t target, struct vsock_transport_recv_notify_data *data)
436 {
437 	return 0;
438 }
439 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
440 
441 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
442 	size_t target, struct vsock_transport_recv_notify_data *data)
443 {
444 	return 0;
445 }
446 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
447 
448 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
449 	size_t target, ssize_t copied, bool data_read,
450 	struct vsock_transport_recv_notify_data *data)
451 {
452 	return 0;
453 }
454 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
455 
456 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
457 	struct vsock_transport_send_notify_data *data)
458 {
459 	return 0;
460 }
461 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
462 
463 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
464 	struct vsock_transport_send_notify_data *data)
465 {
466 	return 0;
467 }
468 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
469 
470 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
471 	struct vsock_transport_send_notify_data *data)
472 {
473 	return 0;
474 }
475 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
476 
477 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
478 	ssize_t written, struct vsock_transport_send_notify_data *data)
479 {
480 	return 0;
481 }
482 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
483 
484 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
485 {
486 	struct virtio_vsock_sock *vvs = vsk->trans;
487 
488 	return vvs->buf_size;
489 }
490 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
491 
492 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
493 {
494 	return true;
495 }
496 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
497 
498 bool virtio_transport_stream_allow(u32 cid, u32 port)
499 {
500 	return true;
501 }
502 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
503 
504 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
505 				struct sockaddr_vm *addr)
506 {
507 	return -EOPNOTSUPP;
508 }
509 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
510 
511 bool virtio_transport_dgram_allow(u32 cid, u32 port)
512 {
513 	return false;
514 }
515 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
516 
517 int virtio_transport_connect(struct vsock_sock *vsk)
518 {
519 	struct virtio_vsock_pkt_info info = {
520 		.op = VIRTIO_VSOCK_OP_REQUEST,
521 		.type = VIRTIO_VSOCK_TYPE_STREAM,
522 	};
523 
524 	return virtio_transport_send_pkt_info(vsk, &info);
525 }
526 EXPORT_SYMBOL_GPL(virtio_transport_connect);
527 
528 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
529 {
530 	struct virtio_vsock_pkt_info info = {
531 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
532 		.type = VIRTIO_VSOCK_TYPE_STREAM,
533 		.flags = (mode & RCV_SHUTDOWN ?
534 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
535 			 (mode & SEND_SHUTDOWN ?
536 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
537 	};
538 
539 	return virtio_transport_send_pkt_info(vsk, &info);
540 }
541 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
542 
543 int
544 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
545 			       struct sockaddr_vm *remote_addr,
546 			       struct msghdr *msg,
547 			       size_t dgram_len)
548 {
549 	return -EOPNOTSUPP;
550 }
551 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
552 
553 ssize_t
554 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
555 				struct msghdr *msg,
556 				size_t len)
557 {
558 	struct virtio_vsock_pkt_info info = {
559 		.op = VIRTIO_VSOCK_OP_RW,
560 		.type = VIRTIO_VSOCK_TYPE_STREAM,
561 		.msg = msg,
562 		.pkt_len = len,
563 	};
564 
565 	return virtio_transport_send_pkt_info(vsk, &info);
566 }
567 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
568 
569 void virtio_transport_destruct(struct vsock_sock *vsk)
570 {
571 	struct virtio_vsock_sock *vvs = vsk->trans;
572 
573 	kfree(vvs);
574 }
575 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
576 
577 static int virtio_transport_reset(struct vsock_sock *vsk,
578 				  struct virtio_vsock_pkt *pkt)
579 {
580 	struct virtio_vsock_pkt_info info = {
581 		.op = VIRTIO_VSOCK_OP_RST,
582 		.type = VIRTIO_VSOCK_TYPE_STREAM,
583 		.reply = !!pkt,
584 	};
585 
586 	/* Send RST only if the original pkt is not a RST pkt */
587 	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
588 		return 0;
589 
590 	return virtio_transport_send_pkt_info(vsk, &info);
591 }
592 
593 /* Normally packets are associated with a socket.  There may be no socket if an
594  * attempt was made to connect to a socket that does not exist.
595  */
596 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
597 {
598 	struct virtio_vsock_pkt_info info = {
599 		.op = VIRTIO_VSOCK_OP_RST,
600 		.type = le16_to_cpu(pkt->hdr.type),
601 		.reply = true,
602 	};
603 
604 	/* Send RST only if the original pkt is not a RST pkt */
605 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
606 		return 0;
607 
608 	pkt = virtio_transport_alloc_pkt(&info, 0,
609 					 le32_to_cpu(pkt->hdr.dst_cid),
610 					 le32_to_cpu(pkt->hdr.dst_port),
611 					 le32_to_cpu(pkt->hdr.src_cid),
612 					 le32_to_cpu(pkt->hdr.src_port));
613 	if (!pkt)
614 		return -ENOMEM;
615 
616 	return virtio_transport_get_ops()->send_pkt(pkt);
617 }
618 
619 static void virtio_transport_wait_close(struct sock *sk, long timeout)
620 {
621 	if (timeout) {
622 		DEFINE_WAIT(wait);
623 
624 		do {
625 			prepare_to_wait(sk_sleep(sk), &wait,
626 					TASK_INTERRUPTIBLE);
627 			if (sk_wait_event(sk, &timeout,
628 					  sock_flag(sk, SOCK_DONE)))
629 				break;
630 		} while (!signal_pending(current) && timeout);
631 
632 		finish_wait(sk_sleep(sk), &wait);
633 	}
634 }
635 
636 static void virtio_transport_do_close(struct vsock_sock *vsk,
637 				      bool cancel_timeout)
638 {
639 	struct sock *sk = sk_vsock(vsk);
640 
641 	sock_set_flag(sk, SOCK_DONE);
642 	vsk->peer_shutdown = SHUTDOWN_MASK;
643 	if (vsock_stream_has_data(vsk) <= 0)
644 		sk->sk_state = SS_DISCONNECTING;
645 	sk->sk_state_change(sk);
646 
647 	if (vsk->close_work_scheduled &&
648 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
649 		vsk->close_work_scheduled = false;
650 
651 		vsock_remove_sock(vsk);
652 
653 		/* Release refcnt obtained when we scheduled the timeout */
654 		sock_put(sk);
655 	}
656 }
657 
658 static void virtio_transport_close_timeout(struct work_struct *work)
659 {
660 	struct vsock_sock *vsk =
661 		container_of(work, struct vsock_sock, close_work.work);
662 	struct sock *sk = sk_vsock(vsk);
663 
664 	sock_hold(sk);
665 	lock_sock(sk);
666 
667 	if (!sock_flag(sk, SOCK_DONE)) {
668 		(void)virtio_transport_reset(vsk, NULL);
669 
670 		virtio_transport_do_close(vsk, false);
671 	}
672 
673 	vsk->close_work_scheduled = false;
674 
675 	release_sock(sk);
676 	sock_put(sk);
677 }
678 
679 /* User context, vsk->sk is locked */
680 static bool virtio_transport_close(struct vsock_sock *vsk)
681 {
682 	struct sock *sk = &vsk->sk;
683 
684 	if (!(sk->sk_state == SS_CONNECTED ||
685 	      sk->sk_state == SS_DISCONNECTING))
686 		return true;
687 
688 	/* Already received SHUTDOWN from peer, reply with RST */
689 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
690 		(void)virtio_transport_reset(vsk, NULL);
691 		return true;
692 	}
693 
694 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
695 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
696 
697 	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
698 		virtio_transport_wait_close(sk, sk->sk_lingertime);
699 
700 	if (sock_flag(sk, SOCK_DONE)) {
701 		return true;
702 	}
703 
704 	sock_hold(sk);
705 	INIT_DELAYED_WORK(&vsk->close_work,
706 			  virtio_transport_close_timeout);
707 	vsk->close_work_scheduled = true;
708 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
709 	return false;
710 }
711 
712 void virtio_transport_release(struct vsock_sock *vsk)
713 {
714 	struct sock *sk = &vsk->sk;
715 	bool remove_sock = true;
716 
717 	lock_sock(sk);
718 	if (sk->sk_type == SOCK_STREAM)
719 		remove_sock = virtio_transport_close(vsk);
720 	release_sock(sk);
721 
722 	if (remove_sock)
723 		vsock_remove_sock(vsk);
724 }
725 EXPORT_SYMBOL_GPL(virtio_transport_release);
726 
727 static int
728 virtio_transport_recv_connecting(struct sock *sk,
729 				 struct virtio_vsock_pkt *pkt)
730 {
731 	struct vsock_sock *vsk = vsock_sk(sk);
732 	int err;
733 	int skerr;
734 
735 	switch (le16_to_cpu(pkt->hdr.op)) {
736 	case VIRTIO_VSOCK_OP_RESPONSE:
737 		sk->sk_state = SS_CONNECTED;
738 		sk->sk_socket->state = SS_CONNECTED;
739 		vsock_insert_connected(vsk);
740 		sk->sk_state_change(sk);
741 		break;
742 	case VIRTIO_VSOCK_OP_INVALID:
743 		break;
744 	case VIRTIO_VSOCK_OP_RST:
745 		skerr = ECONNRESET;
746 		err = 0;
747 		goto destroy;
748 	default:
749 		skerr = EPROTO;
750 		err = -EINVAL;
751 		goto destroy;
752 	}
753 	return 0;
754 
755 destroy:
756 	virtio_transport_reset(vsk, pkt);
757 	sk->sk_state = SS_UNCONNECTED;
758 	sk->sk_err = skerr;
759 	sk->sk_error_report(sk);
760 	return err;
761 }
762 
763 static int
764 virtio_transport_recv_connected(struct sock *sk,
765 				struct virtio_vsock_pkt *pkt)
766 {
767 	struct vsock_sock *vsk = vsock_sk(sk);
768 	struct virtio_vsock_sock *vvs = vsk->trans;
769 	int err = 0;
770 
771 	switch (le16_to_cpu(pkt->hdr.op)) {
772 	case VIRTIO_VSOCK_OP_RW:
773 		pkt->len = le32_to_cpu(pkt->hdr.len);
774 		pkt->off = 0;
775 
776 		spin_lock_bh(&vvs->rx_lock);
777 		virtio_transport_inc_rx_pkt(vvs, pkt);
778 		list_add_tail(&pkt->list, &vvs->rx_queue);
779 		spin_unlock_bh(&vvs->rx_lock);
780 
781 		sk->sk_data_ready(sk);
782 		return err;
783 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
784 		sk->sk_write_space(sk);
785 		break;
786 	case VIRTIO_VSOCK_OP_SHUTDOWN:
787 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
788 			vsk->peer_shutdown |= RCV_SHUTDOWN;
789 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
790 			vsk->peer_shutdown |= SEND_SHUTDOWN;
791 		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
792 		    vsock_stream_has_data(vsk) <= 0)
793 			sk->sk_state = SS_DISCONNECTING;
794 		if (le32_to_cpu(pkt->hdr.flags))
795 			sk->sk_state_change(sk);
796 		break;
797 	case VIRTIO_VSOCK_OP_RST:
798 		virtio_transport_do_close(vsk, true);
799 		break;
800 	default:
801 		err = -EINVAL;
802 		break;
803 	}
804 
805 	virtio_transport_free_pkt(pkt);
806 	return err;
807 }
808 
809 static void
810 virtio_transport_recv_disconnecting(struct sock *sk,
811 				    struct virtio_vsock_pkt *pkt)
812 {
813 	struct vsock_sock *vsk = vsock_sk(sk);
814 
815 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
816 		virtio_transport_do_close(vsk, true);
817 }
818 
819 static int
820 virtio_transport_send_response(struct vsock_sock *vsk,
821 			       struct virtio_vsock_pkt *pkt)
822 {
823 	struct virtio_vsock_pkt_info info = {
824 		.op = VIRTIO_VSOCK_OP_RESPONSE,
825 		.type = VIRTIO_VSOCK_TYPE_STREAM,
826 		.remote_cid = le32_to_cpu(pkt->hdr.src_cid),
827 		.remote_port = le32_to_cpu(pkt->hdr.src_port),
828 		.reply = true,
829 	};
830 
831 	return virtio_transport_send_pkt_info(vsk, &info);
832 }
833 
834 /* Handle server socket */
835 static int
836 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
837 {
838 	struct vsock_sock *vsk = vsock_sk(sk);
839 	struct vsock_sock *vchild;
840 	struct sock *child;
841 
842 	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
843 		virtio_transport_reset(vsk, pkt);
844 		return -EINVAL;
845 	}
846 
847 	if (sk_acceptq_is_full(sk)) {
848 		virtio_transport_reset(vsk, pkt);
849 		return -ENOMEM;
850 	}
851 
852 	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
853 			       sk->sk_type, 0);
854 	if (!child) {
855 		virtio_transport_reset(vsk, pkt);
856 		return -ENOMEM;
857 	}
858 
859 	sk->sk_ack_backlog++;
860 
861 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
862 
863 	child->sk_state = SS_CONNECTED;
864 
865 	vchild = vsock_sk(child);
866 	vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid),
867 			le32_to_cpu(pkt->hdr.dst_port));
868 	vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid),
869 			le32_to_cpu(pkt->hdr.src_port));
870 
871 	vsock_insert_connected(vchild);
872 	vsock_enqueue_accept(sk, child);
873 	virtio_transport_send_response(vchild, pkt);
874 
875 	release_sock(child);
876 
877 	sk->sk_data_ready(sk);
878 	return 0;
879 }
880 
881 static bool virtio_transport_space_update(struct sock *sk,
882 					  struct virtio_vsock_pkt *pkt)
883 {
884 	struct vsock_sock *vsk = vsock_sk(sk);
885 	struct virtio_vsock_sock *vvs = vsk->trans;
886 	bool space_available;
887 
888 	/* buf_alloc and fwd_cnt is always included in the hdr */
889 	spin_lock_bh(&vvs->tx_lock);
890 	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
891 	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
892 	space_available = virtio_transport_has_space(vsk);
893 	spin_unlock_bh(&vvs->tx_lock);
894 	return space_available;
895 }
896 
897 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
898  * lock.
899  */
900 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
901 {
902 	struct sockaddr_vm src, dst;
903 	struct vsock_sock *vsk;
904 	struct sock *sk;
905 	bool space_available;
906 
907 	vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid),
908 			le32_to_cpu(pkt->hdr.src_port));
909 	vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid),
910 			le32_to_cpu(pkt->hdr.dst_port));
911 
912 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
913 					dst.svm_cid, dst.svm_port,
914 					le32_to_cpu(pkt->hdr.len),
915 					le16_to_cpu(pkt->hdr.type),
916 					le16_to_cpu(pkt->hdr.op),
917 					le32_to_cpu(pkt->hdr.flags),
918 					le32_to_cpu(pkt->hdr.buf_alloc),
919 					le32_to_cpu(pkt->hdr.fwd_cnt));
920 
921 	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
922 		(void)virtio_transport_reset_no_sock(pkt);
923 		goto free_pkt;
924 	}
925 
926 	/* The socket must be in connected or bound table
927 	 * otherwise send reset back
928 	 */
929 	sk = vsock_find_connected_socket(&src, &dst);
930 	if (!sk) {
931 		sk = vsock_find_bound_socket(&dst);
932 		if (!sk) {
933 			(void)virtio_transport_reset_no_sock(pkt);
934 			goto free_pkt;
935 		}
936 	}
937 
938 	vsk = vsock_sk(sk);
939 
940 	space_available = virtio_transport_space_update(sk, pkt);
941 
942 	lock_sock(sk);
943 
944 	/* Update CID in case it has changed after a transport reset event */
945 	vsk->local_addr.svm_cid = dst.svm_cid;
946 
947 	if (space_available)
948 		sk->sk_write_space(sk);
949 
950 	switch (sk->sk_state) {
951 	case VSOCK_SS_LISTEN:
952 		virtio_transport_recv_listen(sk, pkt);
953 		virtio_transport_free_pkt(pkt);
954 		break;
955 	case SS_CONNECTING:
956 		virtio_transport_recv_connecting(sk, pkt);
957 		virtio_transport_free_pkt(pkt);
958 		break;
959 	case SS_CONNECTED:
960 		virtio_transport_recv_connected(sk, pkt);
961 		break;
962 	case SS_DISCONNECTING:
963 		virtio_transport_recv_disconnecting(sk, pkt);
964 		virtio_transport_free_pkt(pkt);
965 		break;
966 	default:
967 		virtio_transport_free_pkt(pkt);
968 		break;
969 	}
970 	release_sock(sk);
971 
972 	/* Release refcnt obtained when we fetched this socket out of the
973 	 * bound or connected list.
974 	 */
975 	sock_put(sk);
976 	return;
977 
978 free_pkt:
979 	virtio_transport_free_pkt(pkt);
980 }
981 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
982 
983 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
984 {
985 	kfree(pkt->buf);
986 	kfree(pkt);
987 }
988 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
989 
990 MODULE_LICENSE("GPL v2");
991 MODULE_AUTHOR("Asias He");
992 MODULE_DESCRIPTION("common code for virtio vsock");
993