1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18 
19 #include <net/sock.h>
20 #include <net/af_vsock.h>
21 
22 #define CREATE_TRACE_POINTS
23 #include <trace/events/vsock_virtio_transport_common.h>
24 
25 /* How long to wait for graceful shutdown of a connection */
26 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
27 
28 static const struct virtio_transport *virtio_transport_get_ops(void)
29 {
30 	const struct vsock_transport *t = vsock_core_get_transport();
31 
32 	return container_of(t, struct virtio_transport, transport);
33 }
34 
35 static struct virtio_vsock_pkt *
36 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
37 			   size_t len,
38 			   u32 src_cid,
39 			   u32 src_port,
40 			   u32 dst_cid,
41 			   u32 dst_port)
42 {
43 	struct virtio_vsock_pkt *pkt;
44 	int err;
45 
46 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
47 	if (!pkt)
48 		return NULL;
49 
50 	pkt->hdr.type		= cpu_to_le16(info->type);
51 	pkt->hdr.op		= cpu_to_le16(info->op);
52 	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
53 	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
54 	pkt->hdr.src_port	= cpu_to_le32(src_port);
55 	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
56 	pkt->hdr.flags		= cpu_to_le32(info->flags);
57 	pkt->len		= len;
58 	pkt->hdr.len		= cpu_to_le32(len);
59 	pkt->reply		= info->reply;
60 
61 	if (info->msg && len > 0) {
62 		pkt->buf = kmalloc(len, GFP_KERNEL);
63 		if (!pkt->buf)
64 			goto out_pkt;
65 		err = memcpy_from_msg(pkt->buf, info->msg, len);
66 		if (err)
67 			goto out;
68 	}
69 
70 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
71 					 dst_cid, dst_port,
72 					 len,
73 					 info->type,
74 					 info->op,
75 					 info->flags);
76 
77 	return pkt;
78 
79 out:
80 	kfree(pkt->buf);
81 out_pkt:
82 	kfree(pkt);
83 	return NULL;
84 }
85 
86 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
87 					  struct virtio_vsock_pkt_info *info)
88 {
89 	u32 src_cid, src_port, dst_cid, dst_port;
90 	struct virtio_vsock_sock *vvs;
91 	struct virtio_vsock_pkt *pkt;
92 	u32 pkt_len = info->pkt_len;
93 
94 	src_cid = vm_sockets_get_local_cid();
95 	src_port = vsk->local_addr.svm_port;
96 	if (!info->remote_cid) {
97 		dst_cid	= vsk->remote_addr.svm_cid;
98 		dst_port = vsk->remote_addr.svm_port;
99 	} else {
100 		dst_cid = info->remote_cid;
101 		dst_port = info->remote_port;
102 	}
103 
104 	vvs = vsk->trans;
105 
106 	/* we can send less than pkt_len bytes */
107 	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
108 		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
109 
110 	/* virtio_transport_get_credit might return less than pkt_len credit */
111 	pkt_len = virtio_transport_get_credit(vvs, pkt_len);
112 
113 	/* Do not send zero length OP_RW pkt */
114 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
115 		return pkt_len;
116 
117 	pkt = virtio_transport_alloc_pkt(info, pkt_len,
118 					 src_cid, src_port,
119 					 dst_cid, dst_port);
120 	if (!pkt) {
121 		virtio_transport_put_credit(vvs, pkt_len);
122 		return -ENOMEM;
123 	}
124 
125 	virtio_transport_inc_tx_pkt(vvs, pkt);
126 
127 	return virtio_transport_get_ops()->send_pkt(pkt);
128 }
129 
130 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
131 					struct virtio_vsock_pkt *pkt)
132 {
133 	vvs->rx_bytes += pkt->len;
134 }
135 
136 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
137 					struct virtio_vsock_pkt *pkt)
138 {
139 	vvs->rx_bytes -= pkt->len;
140 	vvs->fwd_cnt += pkt->len;
141 }
142 
143 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
144 {
145 	spin_lock_bh(&vvs->tx_lock);
146 	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
147 	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
148 	spin_unlock_bh(&vvs->tx_lock);
149 }
150 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
151 
152 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
153 {
154 	u32 ret;
155 
156 	spin_lock_bh(&vvs->tx_lock);
157 	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
158 	if (ret > credit)
159 		ret = credit;
160 	vvs->tx_cnt += ret;
161 	spin_unlock_bh(&vvs->tx_lock);
162 
163 	return ret;
164 }
165 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
166 
167 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
168 {
169 	spin_lock_bh(&vvs->tx_lock);
170 	vvs->tx_cnt -= credit;
171 	spin_unlock_bh(&vvs->tx_lock);
172 }
173 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
174 
175 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
176 					       int type,
177 					       struct virtio_vsock_hdr *hdr)
178 {
179 	struct virtio_vsock_pkt_info info = {
180 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
181 		.type = type,
182 	};
183 
184 	return virtio_transport_send_pkt_info(vsk, &info);
185 }
186 
187 static ssize_t
188 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
189 				   struct msghdr *msg,
190 				   size_t len)
191 {
192 	struct virtio_vsock_sock *vvs = vsk->trans;
193 	struct virtio_vsock_pkt *pkt;
194 	size_t bytes, total = 0;
195 	int err = -EFAULT;
196 
197 	spin_lock_bh(&vvs->rx_lock);
198 	while (total < len && !list_empty(&vvs->rx_queue)) {
199 		pkt = list_first_entry(&vvs->rx_queue,
200 				       struct virtio_vsock_pkt, list);
201 
202 		bytes = len - total;
203 		if (bytes > pkt->len - pkt->off)
204 			bytes = pkt->len - pkt->off;
205 
206 		/* sk_lock is held by caller so no one else can dequeue.
207 		 * Unlock rx_lock since memcpy_to_msg() may sleep.
208 		 */
209 		spin_unlock_bh(&vvs->rx_lock);
210 
211 		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
212 		if (err)
213 			goto out;
214 
215 		spin_lock_bh(&vvs->rx_lock);
216 
217 		total += bytes;
218 		pkt->off += bytes;
219 		if (pkt->off == pkt->len) {
220 			virtio_transport_dec_rx_pkt(vvs, pkt);
221 			list_del(&pkt->list);
222 			virtio_transport_free_pkt(pkt);
223 		}
224 	}
225 	spin_unlock_bh(&vvs->rx_lock);
226 
227 	/* Send a credit pkt to peer */
228 	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
229 					    NULL);
230 
231 	return total;
232 
233 out:
234 	if (total)
235 		err = total;
236 	return err;
237 }
238 
239 ssize_t
240 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
241 				struct msghdr *msg,
242 				size_t len, int flags)
243 {
244 	if (flags & MSG_PEEK)
245 		return -EOPNOTSUPP;
246 
247 	return virtio_transport_stream_do_dequeue(vsk, msg, len);
248 }
249 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
250 
251 int
252 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
253 			       struct msghdr *msg,
254 			       size_t len, int flags)
255 {
256 	return -EOPNOTSUPP;
257 }
258 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
259 
260 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
261 {
262 	struct virtio_vsock_sock *vvs = vsk->trans;
263 	s64 bytes;
264 
265 	spin_lock_bh(&vvs->rx_lock);
266 	bytes = vvs->rx_bytes;
267 	spin_unlock_bh(&vvs->rx_lock);
268 
269 	return bytes;
270 }
271 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
272 
273 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
274 {
275 	struct virtio_vsock_sock *vvs = vsk->trans;
276 	s64 bytes;
277 
278 	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
279 	if (bytes < 0)
280 		bytes = 0;
281 
282 	return bytes;
283 }
284 
285 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
286 {
287 	struct virtio_vsock_sock *vvs = vsk->trans;
288 	s64 bytes;
289 
290 	spin_lock_bh(&vvs->tx_lock);
291 	bytes = virtio_transport_has_space(vsk);
292 	spin_unlock_bh(&vvs->tx_lock);
293 
294 	return bytes;
295 }
296 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
297 
298 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
299 				    struct vsock_sock *psk)
300 {
301 	struct virtio_vsock_sock *vvs;
302 
303 	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
304 	if (!vvs)
305 		return -ENOMEM;
306 
307 	vsk->trans = vvs;
308 	vvs->vsk = vsk;
309 	if (psk) {
310 		struct virtio_vsock_sock *ptrans = psk->trans;
311 
312 		vvs->buf_size	= ptrans->buf_size;
313 		vvs->buf_size_min = ptrans->buf_size_min;
314 		vvs->buf_size_max = ptrans->buf_size_max;
315 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
316 	} else {
317 		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
318 		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
319 		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
320 	}
321 
322 	vvs->buf_alloc = vvs->buf_size;
323 
324 	spin_lock_init(&vvs->rx_lock);
325 	spin_lock_init(&vvs->tx_lock);
326 	INIT_LIST_HEAD(&vvs->rx_queue);
327 
328 	return 0;
329 }
330 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
331 
332 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
333 {
334 	struct virtio_vsock_sock *vvs = vsk->trans;
335 
336 	return vvs->buf_size;
337 }
338 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
339 
340 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
341 {
342 	struct virtio_vsock_sock *vvs = vsk->trans;
343 
344 	return vvs->buf_size_min;
345 }
346 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
347 
348 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
349 {
350 	struct virtio_vsock_sock *vvs = vsk->trans;
351 
352 	return vvs->buf_size_max;
353 }
354 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
355 
356 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
357 {
358 	struct virtio_vsock_sock *vvs = vsk->trans;
359 
360 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
361 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
362 	if (val < vvs->buf_size_min)
363 		vvs->buf_size_min = val;
364 	if (val > vvs->buf_size_max)
365 		vvs->buf_size_max = val;
366 	vvs->buf_size = val;
367 	vvs->buf_alloc = val;
368 }
369 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
370 
371 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
372 {
373 	struct virtio_vsock_sock *vvs = vsk->trans;
374 
375 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
376 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
377 	if (val > vvs->buf_size)
378 		vvs->buf_size = val;
379 	vvs->buf_size_min = val;
380 }
381 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
382 
383 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
384 {
385 	struct virtio_vsock_sock *vvs = vsk->trans;
386 
387 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
388 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
389 	if (val < vvs->buf_size)
390 		vvs->buf_size = val;
391 	vvs->buf_size_max = val;
392 }
393 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
394 
395 int
396 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
397 				size_t target,
398 				bool *data_ready_now)
399 {
400 	if (vsock_stream_has_data(vsk))
401 		*data_ready_now = true;
402 	else
403 		*data_ready_now = false;
404 
405 	return 0;
406 }
407 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
408 
409 int
410 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
411 				 size_t target,
412 				 bool *space_avail_now)
413 {
414 	s64 free_space;
415 
416 	free_space = vsock_stream_has_space(vsk);
417 	if (free_space > 0)
418 		*space_avail_now = true;
419 	else if (free_space == 0)
420 		*space_avail_now = false;
421 
422 	return 0;
423 }
424 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
425 
426 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
427 	size_t target, struct vsock_transport_recv_notify_data *data)
428 {
429 	return 0;
430 }
431 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
432 
433 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
434 	size_t target, struct vsock_transport_recv_notify_data *data)
435 {
436 	return 0;
437 }
438 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
439 
440 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
441 	size_t target, struct vsock_transport_recv_notify_data *data)
442 {
443 	return 0;
444 }
445 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
446 
447 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
448 	size_t target, ssize_t copied, bool data_read,
449 	struct vsock_transport_recv_notify_data *data)
450 {
451 	return 0;
452 }
453 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
454 
455 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
456 	struct vsock_transport_send_notify_data *data)
457 {
458 	return 0;
459 }
460 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
461 
462 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
463 	struct vsock_transport_send_notify_data *data)
464 {
465 	return 0;
466 }
467 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
468 
469 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
470 	struct vsock_transport_send_notify_data *data)
471 {
472 	return 0;
473 }
474 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
475 
476 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
477 	ssize_t written, struct vsock_transport_send_notify_data *data)
478 {
479 	return 0;
480 }
481 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
482 
483 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
484 {
485 	struct virtio_vsock_sock *vvs = vsk->trans;
486 
487 	return vvs->buf_size;
488 }
489 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
490 
491 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
492 {
493 	return true;
494 }
495 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
496 
497 bool virtio_transport_stream_allow(u32 cid, u32 port)
498 {
499 	return true;
500 }
501 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
502 
503 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
504 				struct sockaddr_vm *addr)
505 {
506 	return -EOPNOTSUPP;
507 }
508 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
509 
510 bool virtio_transport_dgram_allow(u32 cid, u32 port)
511 {
512 	return false;
513 }
514 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
515 
516 int virtio_transport_connect(struct vsock_sock *vsk)
517 {
518 	struct virtio_vsock_pkt_info info = {
519 		.op = VIRTIO_VSOCK_OP_REQUEST,
520 		.type = VIRTIO_VSOCK_TYPE_STREAM,
521 	};
522 
523 	return virtio_transport_send_pkt_info(vsk, &info);
524 }
525 EXPORT_SYMBOL_GPL(virtio_transport_connect);
526 
527 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
528 {
529 	struct virtio_vsock_pkt_info info = {
530 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
531 		.type = VIRTIO_VSOCK_TYPE_STREAM,
532 		.flags = (mode & RCV_SHUTDOWN ?
533 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
534 			 (mode & SEND_SHUTDOWN ?
535 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
536 	};
537 
538 	return virtio_transport_send_pkt_info(vsk, &info);
539 }
540 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
541 
542 int
543 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
544 			       struct sockaddr_vm *remote_addr,
545 			       struct msghdr *msg,
546 			       size_t dgram_len)
547 {
548 	return -EOPNOTSUPP;
549 }
550 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
551 
552 ssize_t
553 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
554 				struct msghdr *msg,
555 				size_t len)
556 {
557 	struct virtio_vsock_pkt_info info = {
558 		.op = VIRTIO_VSOCK_OP_RW,
559 		.type = VIRTIO_VSOCK_TYPE_STREAM,
560 		.msg = msg,
561 		.pkt_len = len,
562 	};
563 
564 	return virtio_transport_send_pkt_info(vsk, &info);
565 }
566 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
567 
568 void virtio_transport_destruct(struct vsock_sock *vsk)
569 {
570 	struct virtio_vsock_sock *vvs = vsk->trans;
571 
572 	kfree(vvs);
573 }
574 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
575 
576 static int virtio_transport_reset(struct vsock_sock *vsk,
577 				  struct virtio_vsock_pkt *pkt)
578 {
579 	struct virtio_vsock_pkt_info info = {
580 		.op = VIRTIO_VSOCK_OP_RST,
581 		.type = VIRTIO_VSOCK_TYPE_STREAM,
582 		.reply = !!pkt,
583 	};
584 
585 	/* Send RST only if the original pkt is not a RST pkt */
586 	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
587 		return 0;
588 
589 	return virtio_transport_send_pkt_info(vsk, &info);
590 }
591 
592 /* Normally packets are associated with a socket.  There may be no socket if an
593  * attempt was made to connect to a socket that does not exist.
594  */
595 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
596 {
597 	struct virtio_vsock_pkt_info info = {
598 		.op = VIRTIO_VSOCK_OP_RST,
599 		.type = le16_to_cpu(pkt->hdr.type),
600 		.reply = true,
601 	};
602 
603 	/* Send RST only if the original pkt is not a RST pkt */
604 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
605 		return 0;
606 
607 	pkt = virtio_transport_alloc_pkt(&info, 0,
608 					 le64_to_cpu(pkt->hdr.dst_cid),
609 					 le32_to_cpu(pkt->hdr.dst_port),
610 					 le64_to_cpu(pkt->hdr.src_cid),
611 					 le32_to_cpu(pkt->hdr.src_port));
612 	if (!pkt)
613 		return -ENOMEM;
614 
615 	return virtio_transport_get_ops()->send_pkt(pkt);
616 }
617 
618 static void virtio_transport_wait_close(struct sock *sk, long timeout)
619 {
620 	if (timeout) {
621 		DEFINE_WAIT_FUNC(wait, woken_wake_function);
622 
623 		add_wait_queue(sk_sleep(sk), &wait);
624 
625 		do {
626 			if (sk_wait_event(sk, &timeout,
627 					  sock_flag(sk, SOCK_DONE), &wait))
628 				break;
629 		} while (!signal_pending(current) && timeout);
630 
631 		remove_wait_queue(sk_sleep(sk), &wait);
632 	}
633 }
634 
635 static void virtio_transport_do_close(struct vsock_sock *vsk,
636 				      bool cancel_timeout)
637 {
638 	struct sock *sk = sk_vsock(vsk);
639 
640 	sock_set_flag(sk, SOCK_DONE);
641 	vsk->peer_shutdown = SHUTDOWN_MASK;
642 	if (vsock_stream_has_data(vsk) <= 0)
643 		sk->sk_state = SS_DISCONNECTING;
644 	sk->sk_state_change(sk);
645 
646 	if (vsk->close_work_scheduled &&
647 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
648 		vsk->close_work_scheduled = false;
649 
650 		vsock_remove_sock(vsk);
651 
652 		/* Release refcnt obtained when we scheduled the timeout */
653 		sock_put(sk);
654 	}
655 }
656 
657 static void virtio_transport_close_timeout(struct work_struct *work)
658 {
659 	struct vsock_sock *vsk =
660 		container_of(work, struct vsock_sock, close_work.work);
661 	struct sock *sk = sk_vsock(vsk);
662 
663 	sock_hold(sk);
664 	lock_sock(sk);
665 
666 	if (!sock_flag(sk, SOCK_DONE)) {
667 		(void)virtio_transport_reset(vsk, NULL);
668 
669 		virtio_transport_do_close(vsk, false);
670 	}
671 
672 	vsk->close_work_scheduled = false;
673 
674 	release_sock(sk);
675 	sock_put(sk);
676 }
677 
678 /* User context, vsk->sk is locked */
679 static bool virtio_transport_close(struct vsock_sock *vsk)
680 {
681 	struct sock *sk = &vsk->sk;
682 
683 	if (!(sk->sk_state == SS_CONNECTED ||
684 	      sk->sk_state == SS_DISCONNECTING))
685 		return true;
686 
687 	/* Already received SHUTDOWN from peer, reply with RST */
688 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
689 		(void)virtio_transport_reset(vsk, NULL);
690 		return true;
691 	}
692 
693 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
694 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
695 
696 	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
697 		virtio_transport_wait_close(sk, sk->sk_lingertime);
698 
699 	if (sock_flag(sk, SOCK_DONE)) {
700 		return true;
701 	}
702 
703 	sock_hold(sk);
704 	INIT_DELAYED_WORK(&vsk->close_work,
705 			  virtio_transport_close_timeout);
706 	vsk->close_work_scheduled = true;
707 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
708 	return false;
709 }
710 
711 void virtio_transport_release(struct vsock_sock *vsk)
712 {
713 	struct sock *sk = &vsk->sk;
714 	bool remove_sock = true;
715 
716 	lock_sock(sk);
717 	if (sk->sk_type == SOCK_STREAM)
718 		remove_sock = virtio_transport_close(vsk);
719 	release_sock(sk);
720 
721 	if (remove_sock)
722 		vsock_remove_sock(vsk);
723 }
724 EXPORT_SYMBOL_GPL(virtio_transport_release);
725 
726 static int
727 virtio_transport_recv_connecting(struct sock *sk,
728 				 struct virtio_vsock_pkt *pkt)
729 {
730 	struct vsock_sock *vsk = vsock_sk(sk);
731 	int err;
732 	int skerr;
733 
734 	switch (le16_to_cpu(pkt->hdr.op)) {
735 	case VIRTIO_VSOCK_OP_RESPONSE:
736 		sk->sk_state = SS_CONNECTED;
737 		sk->sk_socket->state = SS_CONNECTED;
738 		vsock_insert_connected(vsk);
739 		sk->sk_state_change(sk);
740 		break;
741 	case VIRTIO_VSOCK_OP_INVALID:
742 		break;
743 	case VIRTIO_VSOCK_OP_RST:
744 		skerr = ECONNRESET;
745 		err = 0;
746 		goto destroy;
747 	default:
748 		skerr = EPROTO;
749 		err = -EINVAL;
750 		goto destroy;
751 	}
752 	return 0;
753 
754 destroy:
755 	virtio_transport_reset(vsk, pkt);
756 	sk->sk_state = SS_UNCONNECTED;
757 	sk->sk_err = skerr;
758 	sk->sk_error_report(sk);
759 	return err;
760 }
761 
762 static int
763 virtio_transport_recv_connected(struct sock *sk,
764 				struct virtio_vsock_pkt *pkt)
765 {
766 	struct vsock_sock *vsk = vsock_sk(sk);
767 	struct virtio_vsock_sock *vvs = vsk->trans;
768 	int err = 0;
769 
770 	switch (le16_to_cpu(pkt->hdr.op)) {
771 	case VIRTIO_VSOCK_OP_RW:
772 		pkt->len = le32_to_cpu(pkt->hdr.len);
773 		pkt->off = 0;
774 
775 		spin_lock_bh(&vvs->rx_lock);
776 		virtio_transport_inc_rx_pkt(vvs, pkt);
777 		list_add_tail(&pkt->list, &vvs->rx_queue);
778 		spin_unlock_bh(&vvs->rx_lock);
779 
780 		sk->sk_data_ready(sk);
781 		return err;
782 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
783 		sk->sk_write_space(sk);
784 		break;
785 	case VIRTIO_VSOCK_OP_SHUTDOWN:
786 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
787 			vsk->peer_shutdown |= RCV_SHUTDOWN;
788 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
789 			vsk->peer_shutdown |= SEND_SHUTDOWN;
790 		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
791 		    vsock_stream_has_data(vsk) <= 0)
792 			sk->sk_state = SS_DISCONNECTING;
793 		if (le32_to_cpu(pkt->hdr.flags))
794 			sk->sk_state_change(sk);
795 		break;
796 	case VIRTIO_VSOCK_OP_RST:
797 		virtio_transport_do_close(vsk, true);
798 		break;
799 	default:
800 		err = -EINVAL;
801 		break;
802 	}
803 
804 	virtio_transport_free_pkt(pkt);
805 	return err;
806 }
807 
808 static void
809 virtio_transport_recv_disconnecting(struct sock *sk,
810 				    struct virtio_vsock_pkt *pkt)
811 {
812 	struct vsock_sock *vsk = vsock_sk(sk);
813 
814 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
815 		virtio_transport_do_close(vsk, true);
816 }
817 
818 static int
819 virtio_transport_send_response(struct vsock_sock *vsk,
820 			       struct virtio_vsock_pkt *pkt)
821 {
822 	struct virtio_vsock_pkt_info info = {
823 		.op = VIRTIO_VSOCK_OP_RESPONSE,
824 		.type = VIRTIO_VSOCK_TYPE_STREAM,
825 		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
826 		.remote_port = le32_to_cpu(pkt->hdr.src_port),
827 		.reply = true,
828 	};
829 
830 	return virtio_transport_send_pkt_info(vsk, &info);
831 }
832 
833 /* Handle server socket */
834 static int
835 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
836 {
837 	struct vsock_sock *vsk = vsock_sk(sk);
838 	struct vsock_sock *vchild;
839 	struct sock *child;
840 
841 	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
842 		virtio_transport_reset(vsk, pkt);
843 		return -EINVAL;
844 	}
845 
846 	if (sk_acceptq_is_full(sk)) {
847 		virtio_transport_reset(vsk, pkt);
848 		return -ENOMEM;
849 	}
850 
851 	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
852 			       sk->sk_type, 0);
853 	if (!child) {
854 		virtio_transport_reset(vsk, pkt);
855 		return -ENOMEM;
856 	}
857 
858 	sk->sk_ack_backlog++;
859 
860 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
861 
862 	child->sk_state = SS_CONNECTED;
863 
864 	vchild = vsock_sk(child);
865 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
866 			le32_to_cpu(pkt->hdr.dst_port));
867 	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
868 			le32_to_cpu(pkt->hdr.src_port));
869 
870 	vsock_insert_connected(vchild);
871 	vsock_enqueue_accept(sk, child);
872 	virtio_transport_send_response(vchild, pkt);
873 
874 	release_sock(child);
875 
876 	sk->sk_data_ready(sk);
877 	return 0;
878 }
879 
880 static bool virtio_transport_space_update(struct sock *sk,
881 					  struct virtio_vsock_pkt *pkt)
882 {
883 	struct vsock_sock *vsk = vsock_sk(sk);
884 	struct virtio_vsock_sock *vvs = vsk->trans;
885 	bool space_available;
886 
887 	/* buf_alloc and fwd_cnt is always included in the hdr */
888 	spin_lock_bh(&vvs->tx_lock);
889 	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
890 	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
891 	space_available = virtio_transport_has_space(vsk);
892 	spin_unlock_bh(&vvs->tx_lock);
893 	return space_available;
894 }
895 
896 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
897  * lock.
898  */
899 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
900 {
901 	struct sockaddr_vm src, dst;
902 	struct vsock_sock *vsk;
903 	struct sock *sk;
904 	bool space_available;
905 
906 	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
907 			le32_to_cpu(pkt->hdr.src_port));
908 	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
909 			le32_to_cpu(pkt->hdr.dst_port));
910 
911 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
912 					dst.svm_cid, dst.svm_port,
913 					le32_to_cpu(pkt->hdr.len),
914 					le16_to_cpu(pkt->hdr.type),
915 					le16_to_cpu(pkt->hdr.op),
916 					le32_to_cpu(pkt->hdr.flags),
917 					le32_to_cpu(pkt->hdr.buf_alloc),
918 					le32_to_cpu(pkt->hdr.fwd_cnt));
919 
920 	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
921 		(void)virtio_transport_reset_no_sock(pkt);
922 		goto free_pkt;
923 	}
924 
925 	/* The socket must be in connected or bound table
926 	 * otherwise send reset back
927 	 */
928 	sk = vsock_find_connected_socket(&src, &dst);
929 	if (!sk) {
930 		sk = vsock_find_bound_socket(&dst);
931 		if (!sk) {
932 			(void)virtio_transport_reset_no_sock(pkt);
933 			goto free_pkt;
934 		}
935 	}
936 
937 	vsk = vsock_sk(sk);
938 
939 	space_available = virtio_transport_space_update(sk, pkt);
940 
941 	lock_sock(sk);
942 
943 	/* Update CID in case it has changed after a transport reset event */
944 	vsk->local_addr.svm_cid = dst.svm_cid;
945 
946 	if (space_available)
947 		sk->sk_write_space(sk);
948 
949 	switch (sk->sk_state) {
950 	case VSOCK_SS_LISTEN:
951 		virtio_transport_recv_listen(sk, pkt);
952 		virtio_transport_free_pkt(pkt);
953 		break;
954 	case SS_CONNECTING:
955 		virtio_transport_recv_connecting(sk, pkt);
956 		virtio_transport_free_pkt(pkt);
957 		break;
958 	case SS_CONNECTED:
959 		virtio_transport_recv_connected(sk, pkt);
960 		break;
961 	case SS_DISCONNECTING:
962 		virtio_transport_recv_disconnecting(sk, pkt);
963 		virtio_transport_free_pkt(pkt);
964 		break;
965 	default:
966 		virtio_transport_free_pkt(pkt);
967 		break;
968 	}
969 	release_sock(sk);
970 
971 	/* Release refcnt obtained when we fetched this socket out of the
972 	 * bound or connected list.
973 	 */
974 	sock_put(sk);
975 	return;
976 
977 free_pkt:
978 	virtio_transport_free_pkt(pkt);
979 }
980 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
981 
982 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
983 {
984 	kfree(pkt->buf);
985 	kfree(pkt);
986 }
987 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
988 
989 MODULE_LICENSE("GPL v2");
990 MODULE_AUTHOR("Asias He");
991 MODULE_DESCRIPTION("common code for virtio vsock");
992