1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/sched/signal.h>
13 #include <linux/ctype.h>
14 #include <linux/list.h>
15 #include <linux/virtio.h>
16 #include <linux/virtio_ids.h>
17 #include <linux/virtio_config.h>
18 #include <linux/virtio_vsock.h>
19 
20 #include <net/sock.h>
21 #include <net/af_vsock.h>
22 
23 #define CREATE_TRACE_POINTS
24 #include <trace/events/vsock_virtio_transport_common.h>
25 
26 /* How long to wait for graceful shutdown of a connection */
27 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
28 
29 static const struct virtio_transport *virtio_transport_get_ops(void)
30 {
31 	const struct vsock_transport *t = vsock_core_get_transport();
32 
33 	return container_of(t, struct virtio_transport, transport);
34 }
35 
36 static struct virtio_vsock_pkt *
37 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
38 			   size_t len,
39 			   u32 src_cid,
40 			   u32 src_port,
41 			   u32 dst_cid,
42 			   u32 dst_port)
43 {
44 	struct virtio_vsock_pkt *pkt;
45 	int err;
46 
47 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
48 	if (!pkt)
49 		return NULL;
50 
51 	pkt->hdr.type		= cpu_to_le16(info->type);
52 	pkt->hdr.op		= cpu_to_le16(info->op);
53 	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
54 	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
55 	pkt->hdr.src_port	= cpu_to_le32(src_port);
56 	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
57 	pkt->hdr.flags		= cpu_to_le32(info->flags);
58 	pkt->len		= len;
59 	pkt->hdr.len		= cpu_to_le32(len);
60 	pkt->reply		= info->reply;
61 	pkt->vsk		= info->vsk;
62 
63 	if (info->msg && len > 0) {
64 		pkt->buf = kmalloc(len, GFP_KERNEL);
65 		if (!pkt->buf)
66 			goto out_pkt;
67 		err = memcpy_from_msg(pkt->buf, info->msg, len);
68 		if (err)
69 			goto out;
70 	}
71 
72 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
73 					 dst_cid, dst_port,
74 					 len,
75 					 info->type,
76 					 info->op,
77 					 info->flags);
78 
79 	return pkt;
80 
81 out:
82 	kfree(pkt->buf);
83 out_pkt:
84 	kfree(pkt);
85 	return NULL;
86 }
87 
88 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
89 					  struct virtio_vsock_pkt_info *info)
90 {
91 	u32 src_cid, src_port, dst_cid, dst_port;
92 	struct virtio_vsock_sock *vvs;
93 	struct virtio_vsock_pkt *pkt;
94 	u32 pkt_len = info->pkt_len;
95 
96 	src_cid = vm_sockets_get_local_cid();
97 	src_port = vsk->local_addr.svm_port;
98 	if (!info->remote_cid) {
99 		dst_cid	= vsk->remote_addr.svm_cid;
100 		dst_port = vsk->remote_addr.svm_port;
101 	} else {
102 		dst_cid = info->remote_cid;
103 		dst_port = info->remote_port;
104 	}
105 
106 	vvs = vsk->trans;
107 
108 	/* we can send less than pkt_len bytes */
109 	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
110 		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
111 
112 	/* virtio_transport_get_credit might return less than pkt_len credit */
113 	pkt_len = virtio_transport_get_credit(vvs, pkt_len);
114 
115 	/* Do not send zero length OP_RW pkt */
116 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
117 		return pkt_len;
118 
119 	pkt = virtio_transport_alloc_pkt(info, pkt_len,
120 					 src_cid, src_port,
121 					 dst_cid, dst_port);
122 	if (!pkt) {
123 		virtio_transport_put_credit(vvs, pkt_len);
124 		return -ENOMEM;
125 	}
126 
127 	virtio_transport_inc_tx_pkt(vvs, pkt);
128 
129 	return virtio_transport_get_ops()->send_pkt(pkt);
130 }
131 
132 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
133 					struct virtio_vsock_pkt *pkt)
134 {
135 	vvs->rx_bytes += pkt->len;
136 }
137 
138 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
139 					struct virtio_vsock_pkt *pkt)
140 {
141 	vvs->rx_bytes -= pkt->len;
142 	vvs->fwd_cnt += pkt->len;
143 }
144 
145 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
146 {
147 	spin_lock_bh(&vvs->tx_lock);
148 	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
149 	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
150 	spin_unlock_bh(&vvs->tx_lock);
151 }
152 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
153 
154 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
155 {
156 	u32 ret;
157 
158 	spin_lock_bh(&vvs->tx_lock);
159 	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
160 	if (ret > credit)
161 		ret = credit;
162 	vvs->tx_cnt += ret;
163 	spin_unlock_bh(&vvs->tx_lock);
164 
165 	return ret;
166 }
167 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
168 
169 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
170 {
171 	spin_lock_bh(&vvs->tx_lock);
172 	vvs->tx_cnt -= credit;
173 	spin_unlock_bh(&vvs->tx_lock);
174 }
175 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
176 
177 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
178 					       int type,
179 					       struct virtio_vsock_hdr *hdr)
180 {
181 	struct virtio_vsock_pkt_info info = {
182 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
183 		.type = type,
184 		.vsk = vsk,
185 	};
186 
187 	return virtio_transport_send_pkt_info(vsk, &info);
188 }
189 
190 static ssize_t
191 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
192 				   struct msghdr *msg,
193 				   size_t len)
194 {
195 	struct virtio_vsock_sock *vvs = vsk->trans;
196 	struct virtio_vsock_pkt *pkt;
197 	size_t bytes, total = 0;
198 	int err = -EFAULT;
199 
200 	spin_lock_bh(&vvs->rx_lock);
201 	while (total < len && !list_empty(&vvs->rx_queue)) {
202 		pkt = list_first_entry(&vvs->rx_queue,
203 				       struct virtio_vsock_pkt, list);
204 
205 		bytes = len - total;
206 		if (bytes > pkt->len - pkt->off)
207 			bytes = pkt->len - pkt->off;
208 
209 		/* sk_lock is held by caller so no one else can dequeue.
210 		 * Unlock rx_lock since memcpy_to_msg() may sleep.
211 		 */
212 		spin_unlock_bh(&vvs->rx_lock);
213 
214 		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
215 		if (err)
216 			goto out;
217 
218 		spin_lock_bh(&vvs->rx_lock);
219 
220 		total += bytes;
221 		pkt->off += bytes;
222 		if (pkt->off == pkt->len) {
223 			virtio_transport_dec_rx_pkt(vvs, pkt);
224 			list_del(&pkt->list);
225 			virtio_transport_free_pkt(pkt);
226 		}
227 	}
228 	spin_unlock_bh(&vvs->rx_lock);
229 
230 	/* Send a credit pkt to peer */
231 	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
232 					    NULL);
233 
234 	return total;
235 
236 out:
237 	if (total)
238 		err = total;
239 	return err;
240 }
241 
242 ssize_t
243 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
244 				struct msghdr *msg,
245 				size_t len, int flags)
246 {
247 	if (flags & MSG_PEEK)
248 		return -EOPNOTSUPP;
249 
250 	return virtio_transport_stream_do_dequeue(vsk, msg, len);
251 }
252 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
253 
254 int
255 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
256 			       struct msghdr *msg,
257 			       size_t len, int flags)
258 {
259 	return -EOPNOTSUPP;
260 }
261 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
262 
263 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
264 {
265 	struct virtio_vsock_sock *vvs = vsk->trans;
266 	s64 bytes;
267 
268 	spin_lock_bh(&vvs->rx_lock);
269 	bytes = vvs->rx_bytes;
270 	spin_unlock_bh(&vvs->rx_lock);
271 
272 	return bytes;
273 }
274 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
275 
276 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
277 {
278 	struct virtio_vsock_sock *vvs = vsk->trans;
279 	s64 bytes;
280 
281 	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
282 	if (bytes < 0)
283 		bytes = 0;
284 
285 	return bytes;
286 }
287 
288 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
289 {
290 	struct virtio_vsock_sock *vvs = vsk->trans;
291 	s64 bytes;
292 
293 	spin_lock_bh(&vvs->tx_lock);
294 	bytes = virtio_transport_has_space(vsk);
295 	spin_unlock_bh(&vvs->tx_lock);
296 
297 	return bytes;
298 }
299 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
300 
301 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
302 				    struct vsock_sock *psk)
303 {
304 	struct virtio_vsock_sock *vvs;
305 
306 	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
307 	if (!vvs)
308 		return -ENOMEM;
309 
310 	vsk->trans = vvs;
311 	vvs->vsk = vsk;
312 	if (psk) {
313 		struct virtio_vsock_sock *ptrans = psk->trans;
314 
315 		vvs->buf_size	= ptrans->buf_size;
316 		vvs->buf_size_min = ptrans->buf_size_min;
317 		vvs->buf_size_max = ptrans->buf_size_max;
318 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
319 	} else {
320 		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
321 		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
322 		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
323 	}
324 
325 	vvs->buf_alloc = vvs->buf_size;
326 
327 	spin_lock_init(&vvs->rx_lock);
328 	spin_lock_init(&vvs->tx_lock);
329 	INIT_LIST_HEAD(&vvs->rx_queue);
330 
331 	return 0;
332 }
333 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
334 
335 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
336 {
337 	struct virtio_vsock_sock *vvs = vsk->trans;
338 
339 	return vvs->buf_size;
340 }
341 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
342 
343 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
344 {
345 	struct virtio_vsock_sock *vvs = vsk->trans;
346 
347 	return vvs->buf_size_min;
348 }
349 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
350 
351 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
352 {
353 	struct virtio_vsock_sock *vvs = vsk->trans;
354 
355 	return vvs->buf_size_max;
356 }
357 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
358 
359 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
360 {
361 	struct virtio_vsock_sock *vvs = vsk->trans;
362 
363 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
364 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
365 	if (val < vvs->buf_size_min)
366 		vvs->buf_size_min = val;
367 	if (val > vvs->buf_size_max)
368 		vvs->buf_size_max = val;
369 	vvs->buf_size = val;
370 	vvs->buf_alloc = val;
371 }
372 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
373 
374 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
375 {
376 	struct virtio_vsock_sock *vvs = vsk->trans;
377 
378 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
379 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
380 	if (val > vvs->buf_size)
381 		vvs->buf_size = val;
382 	vvs->buf_size_min = val;
383 }
384 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
385 
386 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
387 {
388 	struct virtio_vsock_sock *vvs = vsk->trans;
389 
390 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
391 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
392 	if (val < vvs->buf_size)
393 		vvs->buf_size = val;
394 	vvs->buf_size_max = val;
395 }
396 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
397 
398 int
399 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
400 				size_t target,
401 				bool *data_ready_now)
402 {
403 	if (vsock_stream_has_data(vsk))
404 		*data_ready_now = true;
405 	else
406 		*data_ready_now = false;
407 
408 	return 0;
409 }
410 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
411 
412 int
413 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
414 				 size_t target,
415 				 bool *space_avail_now)
416 {
417 	s64 free_space;
418 
419 	free_space = vsock_stream_has_space(vsk);
420 	if (free_space > 0)
421 		*space_avail_now = true;
422 	else if (free_space == 0)
423 		*space_avail_now = false;
424 
425 	return 0;
426 }
427 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
428 
429 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
430 	size_t target, struct vsock_transport_recv_notify_data *data)
431 {
432 	return 0;
433 }
434 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
435 
436 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
437 	size_t target, struct vsock_transport_recv_notify_data *data)
438 {
439 	return 0;
440 }
441 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
442 
443 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
444 	size_t target, struct vsock_transport_recv_notify_data *data)
445 {
446 	return 0;
447 }
448 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
449 
450 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
451 	size_t target, ssize_t copied, bool data_read,
452 	struct vsock_transport_recv_notify_data *data)
453 {
454 	return 0;
455 }
456 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
457 
458 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
459 	struct vsock_transport_send_notify_data *data)
460 {
461 	return 0;
462 }
463 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
464 
465 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
466 	struct vsock_transport_send_notify_data *data)
467 {
468 	return 0;
469 }
470 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
471 
472 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
473 	struct vsock_transport_send_notify_data *data)
474 {
475 	return 0;
476 }
477 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
478 
479 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
480 	ssize_t written, struct vsock_transport_send_notify_data *data)
481 {
482 	return 0;
483 }
484 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
485 
486 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
487 {
488 	struct virtio_vsock_sock *vvs = vsk->trans;
489 
490 	return vvs->buf_size;
491 }
492 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
493 
494 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
495 {
496 	return true;
497 }
498 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
499 
500 bool virtio_transport_stream_allow(u32 cid, u32 port)
501 {
502 	return true;
503 }
504 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
505 
506 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
507 				struct sockaddr_vm *addr)
508 {
509 	return -EOPNOTSUPP;
510 }
511 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
512 
513 bool virtio_transport_dgram_allow(u32 cid, u32 port)
514 {
515 	return false;
516 }
517 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
518 
519 int virtio_transport_connect(struct vsock_sock *vsk)
520 {
521 	struct virtio_vsock_pkt_info info = {
522 		.op = VIRTIO_VSOCK_OP_REQUEST,
523 		.type = VIRTIO_VSOCK_TYPE_STREAM,
524 		.vsk = vsk,
525 	};
526 
527 	return virtio_transport_send_pkt_info(vsk, &info);
528 }
529 EXPORT_SYMBOL_GPL(virtio_transport_connect);
530 
531 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
532 {
533 	struct virtio_vsock_pkt_info info = {
534 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
535 		.type = VIRTIO_VSOCK_TYPE_STREAM,
536 		.flags = (mode & RCV_SHUTDOWN ?
537 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
538 			 (mode & SEND_SHUTDOWN ?
539 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
540 		.vsk = vsk,
541 	};
542 
543 	return virtio_transport_send_pkt_info(vsk, &info);
544 }
545 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
546 
547 int
548 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
549 			       struct sockaddr_vm *remote_addr,
550 			       struct msghdr *msg,
551 			       size_t dgram_len)
552 {
553 	return -EOPNOTSUPP;
554 }
555 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
556 
557 ssize_t
558 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
559 				struct msghdr *msg,
560 				size_t len)
561 {
562 	struct virtio_vsock_pkt_info info = {
563 		.op = VIRTIO_VSOCK_OP_RW,
564 		.type = VIRTIO_VSOCK_TYPE_STREAM,
565 		.msg = msg,
566 		.pkt_len = len,
567 		.vsk = vsk,
568 	};
569 
570 	return virtio_transport_send_pkt_info(vsk, &info);
571 }
572 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
573 
574 void virtio_transport_destruct(struct vsock_sock *vsk)
575 {
576 	struct virtio_vsock_sock *vvs = vsk->trans;
577 
578 	kfree(vvs);
579 }
580 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
581 
582 static int virtio_transport_reset(struct vsock_sock *vsk,
583 				  struct virtio_vsock_pkt *pkt)
584 {
585 	struct virtio_vsock_pkt_info info = {
586 		.op = VIRTIO_VSOCK_OP_RST,
587 		.type = VIRTIO_VSOCK_TYPE_STREAM,
588 		.reply = !!pkt,
589 		.vsk = vsk,
590 	};
591 
592 	/* Send RST only if the original pkt is not a RST pkt */
593 	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
594 		return 0;
595 
596 	return virtio_transport_send_pkt_info(vsk, &info);
597 }
598 
599 /* Normally packets are associated with a socket.  There may be no socket if an
600  * attempt was made to connect to a socket that does not exist.
601  */
602 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
603 {
604 	struct virtio_vsock_pkt_info info = {
605 		.op = VIRTIO_VSOCK_OP_RST,
606 		.type = le16_to_cpu(pkt->hdr.type),
607 		.reply = true,
608 	};
609 
610 	/* Send RST only if the original pkt is not a RST pkt */
611 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
612 		return 0;
613 
614 	pkt = virtio_transport_alloc_pkt(&info, 0,
615 					 le64_to_cpu(pkt->hdr.dst_cid),
616 					 le32_to_cpu(pkt->hdr.dst_port),
617 					 le64_to_cpu(pkt->hdr.src_cid),
618 					 le32_to_cpu(pkt->hdr.src_port));
619 	if (!pkt)
620 		return -ENOMEM;
621 
622 	return virtio_transport_get_ops()->send_pkt(pkt);
623 }
624 
625 static void virtio_transport_wait_close(struct sock *sk, long timeout)
626 {
627 	if (timeout) {
628 		DEFINE_WAIT_FUNC(wait, woken_wake_function);
629 
630 		add_wait_queue(sk_sleep(sk), &wait);
631 
632 		do {
633 			if (sk_wait_event(sk, &timeout,
634 					  sock_flag(sk, SOCK_DONE), &wait))
635 				break;
636 		} while (!signal_pending(current) && timeout);
637 
638 		remove_wait_queue(sk_sleep(sk), &wait);
639 	}
640 }
641 
642 static void virtio_transport_do_close(struct vsock_sock *vsk,
643 				      bool cancel_timeout)
644 {
645 	struct sock *sk = sk_vsock(vsk);
646 
647 	sock_set_flag(sk, SOCK_DONE);
648 	vsk->peer_shutdown = SHUTDOWN_MASK;
649 	if (vsock_stream_has_data(vsk) <= 0)
650 		sk->sk_state = SS_DISCONNECTING;
651 	sk->sk_state_change(sk);
652 
653 	if (vsk->close_work_scheduled &&
654 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
655 		vsk->close_work_scheduled = false;
656 
657 		vsock_remove_sock(vsk);
658 
659 		/* Release refcnt obtained when we scheduled the timeout */
660 		sock_put(sk);
661 	}
662 }
663 
664 static void virtio_transport_close_timeout(struct work_struct *work)
665 {
666 	struct vsock_sock *vsk =
667 		container_of(work, struct vsock_sock, close_work.work);
668 	struct sock *sk = sk_vsock(vsk);
669 
670 	sock_hold(sk);
671 	lock_sock(sk);
672 
673 	if (!sock_flag(sk, SOCK_DONE)) {
674 		(void)virtio_transport_reset(vsk, NULL);
675 
676 		virtio_transport_do_close(vsk, false);
677 	}
678 
679 	vsk->close_work_scheduled = false;
680 
681 	release_sock(sk);
682 	sock_put(sk);
683 }
684 
685 /* User context, vsk->sk is locked */
686 static bool virtio_transport_close(struct vsock_sock *vsk)
687 {
688 	struct sock *sk = &vsk->sk;
689 
690 	if (!(sk->sk_state == SS_CONNECTED ||
691 	      sk->sk_state == SS_DISCONNECTING))
692 		return true;
693 
694 	/* Already received SHUTDOWN from peer, reply with RST */
695 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
696 		(void)virtio_transport_reset(vsk, NULL);
697 		return true;
698 	}
699 
700 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
701 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
702 
703 	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
704 		virtio_transport_wait_close(sk, sk->sk_lingertime);
705 
706 	if (sock_flag(sk, SOCK_DONE)) {
707 		return true;
708 	}
709 
710 	sock_hold(sk);
711 	INIT_DELAYED_WORK(&vsk->close_work,
712 			  virtio_transport_close_timeout);
713 	vsk->close_work_scheduled = true;
714 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
715 	return false;
716 }
717 
718 void virtio_transport_release(struct vsock_sock *vsk)
719 {
720 	struct sock *sk = &vsk->sk;
721 	bool remove_sock = true;
722 
723 	lock_sock(sk);
724 	if (sk->sk_type == SOCK_STREAM)
725 		remove_sock = virtio_transport_close(vsk);
726 	release_sock(sk);
727 
728 	if (remove_sock)
729 		vsock_remove_sock(vsk);
730 }
731 EXPORT_SYMBOL_GPL(virtio_transport_release);
732 
733 static int
734 virtio_transport_recv_connecting(struct sock *sk,
735 				 struct virtio_vsock_pkt *pkt)
736 {
737 	struct vsock_sock *vsk = vsock_sk(sk);
738 	int err;
739 	int skerr;
740 
741 	switch (le16_to_cpu(pkt->hdr.op)) {
742 	case VIRTIO_VSOCK_OP_RESPONSE:
743 		sk->sk_state = SS_CONNECTED;
744 		sk->sk_socket->state = SS_CONNECTED;
745 		vsock_insert_connected(vsk);
746 		sk->sk_state_change(sk);
747 		break;
748 	case VIRTIO_VSOCK_OP_INVALID:
749 		break;
750 	case VIRTIO_VSOCK_OP_RST:
751 		skerr = ECONNRESET;
752 		err = 0;
753 		goto destroy;
754 	default:
755 		skerr = EPROTO;
756 		err = -EINVAL;
757 		goto destroy;
758 	}
759 	return 0;
760 
761 destroy:
762 	virtio_transport_reset(vsk, pkt);
763 	sk->sk_state = SS_UNCONNECTED;
764 	sk->sk_err = skerr;
765 	sk->sk_error_report(sk);
766 	return err;
767 }
768 
769 static int
770 virtio_transport_recv_connected(struct sock *sk,
771 				struct virtio_vsock_pkt *pkt)
772 {
773 	struct vsock_sock *vsk = vsock_sk(sk);
774 	struct virtio_vsock_sock *vvs = vsk->trans;
775 	int err = 0;
776 
777 	switch (le16_to_cpu(pkt->hdr.op)) {
778 	case VIRTIO_VSOCK_OP_RW:
779 		pkt->len = le32_to_cpu(pkt->hdr.len);
780 		pkt->off = 0;
781 
782 		spin_lock_bh(&vvs->rx_lock);
783 		virtio_transport_inc_rx_pkt(vvs, pkt);
784 		list_add_tail(&pkt->list, &vvs->rx_queue);
785 		spin_unlock_bh(&vvs->rx_lock);
786 
787 		sk->sk_data_ready(sk);
788 		return err;
789 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
790 		sk->sk_write_space(sk);
791 		break;
792 	case VIRTIO_VSOCK_OP_SHUTDOWN:
793 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
794 			vsk->peer_shutdown |= RCV_SHUTDOWN;
795 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
796 			vsk->peer_shutdown |= SEND_SHUTDOWN;
797 		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
798 		    vsock_stream_has_data(vsk) <= 0)
799 			sk->sk_state = SS_DISCONNECTING;
800 		if (le32_to_cpu(pkt->hdr.flags))
801 			sk->sk_state_change(sk);
802 		break;
803 	case VIRTIO_VSOCK_OP_RST:
804 		virtio_transport_do_close(vsk, true);
805 		break;
806 	default:
807 		err = -EINVAL;
808 		break;
809 	}
810 
811 	virtio_transport_free_pkt(pkt);
812 	return err;
813 }
814 
815 static void
816 virtio_transport_recv_disconnecting(struct sock *sk,
817 				    struct virtio_vsock_pkt *pkt)
818 {
819 	struct vsock_sock *vsk = vsock_sk(sk);
820 
821 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
822 		virtio_transport_do_close(vsk, true);
823 }
824 
825 static int
826 virtio_transport_send_response(struct vsock_sock *vsk,
827 			       struct virtio_vsock_pkt *pkt)
828 {
829 	struct virtio_vsock_pkt_info info = {
830 		.op = VIRTIO_VSOCK_OP_RESPONSE,
831 		.type = VIRTIO_VSOCK_TYPE_STREAM,
832 		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
833 		.remote_port = le32_to_cpu(pkt->hdr.src_port),
834 		.reply = true,
835 		.vsk = vsk,
836 	};
837 
838 	return virtio_transport_send_pkt_info(vsk, &info);
839 }
840 
841 /* Handle server socket */
842 static int
843 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
844 {
845 	struct vsock_sock *vsk = vsock_sk(sk);
846 	struct vsock_sock *vchild;
847 	struct sock *child;
848 
849 	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
850 		virtio_transport_reset(vsk, pkt);
851 		return -EINVAL;
852 	}
853 
854 	if (sk_acceptq_is_full(sk)) {
855 		virtio_transport_reset(vsk, pkt);
856 		return -ENOMEM;
857 	}
858 
859 	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
860 			       sk->sk_type, 0);
861 	if (!child) {
862 		virtio_transport_reset(vsk, pkt);
863 		return -ENOMEM;
864 	}
865 
866 	sk->sk_ack_backlog++;
867 
868 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
869 
870 	child->sk_state = SS_CONNECTED;
871 
872 	vchild = vsock_sk(child);
873 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
874 			le32_to_cpu(pkt->hdr.dst_port));
875 	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
876 			le32_to_cpu(pkt->hdr.src_port));
877 
878 	vsock_insert_connected(vchild);
879 	vsock_enqueue_accept(sk, child);
880 	virtio_transport_send_response(vchild, pkt);
881 
882 	release_sock(child);
883 
884 	sk->sk_data_ready(sk);
885 	return 0;
886 }
887 
888 static bool virtio_transport_space_update(struct sock *sk,
889 					  struct virtio_vsock_pkt *pkt)
890 {
891 	struct vsock_sock *vsk = vsock_sk(sk);
892 	struct virtio_vsock_sock *vvs = vsk->trans;
893 	bool space_available;
894 
895 	/* buf_alloc and fwd_cnt is always included in the hdr */
896 	spin_lock_bh(&vvs->tx_lock);
897 	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
898 	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
899 	space_available = virtio_transport_has_space(vsk);
900 	spin_unlock_bh(&vvs->tx_lock);
901 	return space_available;
902 }
903 
904 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
905  * lock.
906  */
907 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
908 {
909 	struct sockaddr_vm src, dst;
910 	struct vsock_sock *vsk;
911 	struct sock *sk;
912 	bool space_available;
913 
914 	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
915 			le32_to_cpu(pkt->hdr.src_port));
916 	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
917 			le32_to_cpu(pkt->hdr.dst_port));
918 
919 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
920 					dst.svm_cid, dst.svm_port,
921 					le32_to_cpu(pkt->hdr.len),
922 					le16_to_cpu(pkt->hdr.type),
923 					le16_to_cpu(pkt->hdr.op),
924 					le32_to_cpu(pkt->hdr.flags),
925 					le32_to_cpu(pkt->hdr.buf_alloc),
926 					le32_to_cpu(pkt->hdr.fwd_cnt));
927 
928 	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
929 		(void)virtio_transport_reset_no_sock(pkt);
930 		goto free_pkt;
931 	}
932 
933 	/* The socket must be in connected or bound table
934 	 * otherwise send reset back
935 	 */
936 	sk = vsock_find_connected_socket(&src, &dst);
937 	if (!sk) {
938 		sk = vsock_find_bound_socket(&dst);
939 		if (!sk) {
940 			(void)virtio_transport_reset_no_sock(pkt);
941 			goto free_pkt;
942 		}
943 	}
944 
945 	vsk = vsock_sk(sk);
946 
947 	space_available = virtio_transport_space_update(sk, pkt);
948 
949 	lock_sock(sk);
950 
951 	/* Update CID in case it has changed after a transport reset event */
952 	vsk->local_addr.svm_cid = dst.svm_cid;
953 
954 	if (space_available)
955 		sk->sk_write_space(sk);
956 
957 	switch (sk->sk_state) {
958 	case VSOCK_SS_LISTEN:
959 		virtio_transport_recv_listen(sk, pkt);
960 		virtio_transport_free_pkt(pkt);
961 		break;
962 	case SS_CONNECTING:
963 		virtio_transport_recv_connecting(sk, pkt);
964 		virtio_transport_free_pkt(pkt);
965 		break;
966 	case SS_CONNECTED:
967 		virtio_transport_recv_connected(sk, pkt);
968 		break;
969 	case SS_DISCONNECTING:
970 		virtio_transport_recv_disconnecting(sk, pkt);
971 		virtio_transport_free_pkt(pkt);
972 		break;
973 	default:
974 		virtio_transport_free_pkt(pkt);
975 		break;
976 	}
977 	release_sock(sk);
978 
979 	/* Release refcnt obtained when we fetched this socket out of the
980 	 * bound or connected list.
981 	 */
982 	sock_put(sk);
983 	return;
984 
985 free_pkt:
986 	virtio_transport_free_pkt(pkt);
987 }
988 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
989 
990 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
991 {
992 	kfree(pkt->buf);
993 	kfree(pkt);
994 }
995 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
996 
997 MODULE_LICENSE("GPL v2");
998 MODULE_AUTHOR("Asias He");
999 MODULE_DESCRIPTION("common code for virtio vsock");
1000