1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * common code for virtio vsock
4  *
5  * Copyright (C) 2013-2015 Red Hat, Inc.
6  * Author: Asias He <asias@redhat.com>
7  *         Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 #include <linux/spinlock.h>
10 #include <linux/module.h>
11 #include <linux/sched/signal.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18 #include <uapi/linux/vsockmon.h>
19 
20 #include <net/sock.h>
21 #include <net/af_vsock.h>
22 
23 #define CREATE_TRACE_POINTS
24 #include <trace/events/vsock_virtio_transport_common.h>
25 
26 /* How long to wait for graceful shutdown of a connection */
27 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
28 
29 /* Threshold for detecting small packets to copy */
30 #define GOOD_COPY_LEN  128
31 
32 static const struct virtio_transport *virtio_transport_get_ops(void)
33 {
34 	const struct vsock_transport *t = vsock_core_get_transport();
35 
36 	return container_of(t, struct virtio_transport, transport);
37 }
38 
39 static struct virtio_vsock_pkt *
40 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
41 			   size_t len,
42 			   u32 src_cid,
43 			   u32 src_port,
44 			   u32 dst_cid,
45 			   u32 dst_port)
46 {
47 	struct virtio_vsock_pkt *pkt;
48 	int err;
49 
50 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
51 	if (!pkt)
52 		return NULL;
53 
54 	pkt->hdr.type		= cpu_to_le16(info->type);
55 	pkt->hdr.op		= cpu_to_le16(info->op);
56 	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
57 	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
58 	pkt->hdr.src_port	= cpu_to_le32(src_port);
59 	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
60 	pkt->hdr.flags		= cpu_to_le32(info->flags);
61 	pkt->len		= len;
62 	pkt->hdr.len		= cpu_to_le32(len);
63 	pkt->reply		= info->reply;
64 	pkt->vsk		= info->vsk;
65 
66 	if (info->msg && len > 0) {
67 		pkt->buf = kmalloc(len, GFP_KERNEL);
68 		if (!pkt->buf)
69 			goto out_pkt;
70 
71 		pkt->buf_len = len;
72 
73 		err = memcpy_from_msg(pkt->buf, info->msg, len);
74 		if (err)
75 			goto out;
76 	}
77 
78 	trace_virtio_transport_alloc_pkt(src_cid, src_port,
79 					 dst_cid, dst_port,
80 					 len,
81 					 info->type,
82 					 info->op,
83 					 info->flags);
84 
85 	return pkt;
86 
87 out:
88 	kfree(pkt->buf);
89 out_pkt:
90 	kfree(pkt);
91 	return NULL;
92 }
93 
94 /* Packet capture */
95 static struct sk_buff *virtio_transport_build_skb(void *opaque)
96 {
97 	struct virtio_vsock_pkt *pkt = opaque;
98 	struct af_vsockmon_hdr *hdr;
99 	struct sk_buff *skb;
100 	size_t payload_len;
101 	void *payload_buf;
102 
103 	/* A packet could be split to fit the RX buffer, so we can retrieve
104 	 * the payload length from the header and the buffer pointer taking
105 	 * care of the offset in the original packet.
106 	 */
107 	payload_len = le32_to_cpu(pkt->hdr.len);
108 	payload_buf = pkt->buf + pkt->off;
109 
110 	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
111 			GFP_ATOMIC);
112 	if (!skb)
113 		return NULL;
114 
115 	hdr = skb_put(skb, sizeof(*hdr));
116 
117 	/* pkt->hdr is little-endian so no need to byteswap here */
118 	hdr->src_cid = pkt->hdr.src_cid;
119 	hdr->src_port = pkt->hdr.src_port;
120 	hdr->dst_cid = pkt->hdr.dst_cid;
121 	hdr->dst_port = pkt->hdr.dst_port;
122 
123 	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
124 	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
125 	memset(hdr->reserved, 0, sizeof(hdr->reserved));
126 
127 	switch (le16_to_cpu(pkt->hdr.op)) {
128 	case VIRTIO_VSOCK_OP_REQUEST:
129 	case VIRTIO_VSOCK_OP_RESPONSE:
130 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
131 		break;
132 	case VIRTIO_VSOCK_OP_RST:
133 	case VIRTIO_VSOCK_OP_SHUTDOWN:
134 		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
135 		break;
136 	case VIRTIO_VSOCK_OP_RW:
137 		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
138 		break;
139 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
140 	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
141 		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
142 		break;
143 	default:
144 		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
145 		break;
146 	}
147 
148 	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
149 
150 	if (payload_len) {
151 		skb_put_data(skb, payload_buf, payload_len);
152 	}
153 
154 	return skb;
155 }
156 
157 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
158 {
159 	vsock_deliver_tap(virtio_transport_build_skb, pkt);
160 }
161 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
162 
163 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
164 					  struct virtio_vsock_pkt_info *info)
165 {
166 	u32 src_cid, src_port, dst_cid, dst_port;
167 	struct virtio_vsock_sock *vvs;
168 	struct virtio_vsock_pkt *pkt;
169 	u32 pkt_len = info->pkt_len;
170 
171 	src_cid = vm_sockets_get_local_cid();
172 	src_port = vsk->local_addr.svm_port;
173 	if (!info->remote_cid) {
174 		dst_cid	= vsk->remote_addr.svm_cid;
175 		dst_port = vsk->remote_addr.svm_port;
176 	} else {
177 		dst_cid = info->remote_cid;
178 		dst_port = info->remote_port;
179 	}
180 
181 	vvs = vsk->trans;
182 
183 	/* we can send less than pkt_len bytes */
184 	if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
185 		pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
186 
187 	/* virtio_transport_get_credit might return less than pkt_len credit */
188 	pkt_len = virtio_transport_get_credit(vvs, pkt_len);
189 
190 	/* Do not send zero length OP_RW pkt */
191 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
192 		return pkt_len;
193 
194 	pkt = virtio_transport_alloc_pkt(info, pkt_len,
195 					 src_cid, src_port,
196 					 dst_cid, dst_port);
197 	if (!pkt) {
198 		virtio_transport_put_credit(vvs, pkt_len);
199 		return -ENOMEM;
200 	}
201 
202 	virtio_transport_inc_tx_pkt(vvs, pkt);
203 
204 	return virtio_transport_get_ops()->send_pkt(pkt);
205 }
206 
207 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
208 					struct virtio_vsock_pkt *pkt)
209 {
210 	vvs->rx_bytes += pkt->len;
211 }
212 
213 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
214 					struct virtio_vsock_pkt *pkt)
215 {
216 	vvs->rx_bytes -= pkt->len;
217 	vvs->fwd_cnt += pkt->len;
218 }
219 
220 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
221 {
222 	spin_lock_bh(&vvs->rx_lock);
223 	vvs->last_fwd_cnt = vvs->fwd_cnt;
224 	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
225 	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
226 	spin_unlock_bh(&vvs->rx_lock);
227 }
228 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
229 
230 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
231 {
232 	u32 ret;
233 
234 	spin_lock_bh(&vvs->tx_lock);
235 	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
236 	if (ret > credit)
237 		ret = credit;
238 	vvs->tx_cnt += ret;
239 	spin_unlock_bh(&vvs->tx_lock);
240 
241 	return ret;
242 }
243 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
244 
245 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
246 {
247 	spin_lock_bh(&vvs->tx_lock);
248 	vvs->tx_cnt -= credit;
249 	spin_unlock_bh(&vvs->tx_lock);
250 }
251 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
252 
253 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
254 					       int type,
255 					       struct virtio_vsock_hdr *hdr)
256 {
257 	struct virtio_vsock_pkt_info info = {
258 		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
259 		.type = type,
260 		.vsk = vsk,
261 	};
262 
263 	return virtio_transport_send_pkt_info(vsk, &info);
264 }
265 
266 static ssize_t
267 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
268 				   struct msghdr *msg,
269 				   size_t len)
270 {
271 	struct virtio_vsock_sock *vvs = vsk->trans;
272 	struct virtio_vsock_pkt *pkt;
273 	size_t bytes, total = 0;
274 	u32 free_space;
275 	int err = -EFAULT;
276 
277 	spin_lock_bh(&vvs->rx_lock);
278 	while (total < len && !list_empty(&vvs->rx_queue)) {
279 		pkt = list_first_entry(&vvs->rx_queue,
280 				       struct virtio_vsock_pkt, list);
281 
282 		bytes = len - total;
283 		if (bytes > pkt->len - pkt->off)
284 			bytes = pkt->len - pkt->off;
285 
286 		/* sk_lock is held by caller so no one else can dequeue.
287 		 * Unlock rx_lock since memcpy_to_msg() may sleep.
288 		 */
289 		spin_unlock_bh(&vvs->rx_lock);
290 
291 		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
292 		if (err)
293 			goto out;
294 
295 		spin_lock_bh(&vvs->rx_lock);
296 
297 		total += bytes;
298 		pkt->off += bytes;
299 		if (pkt->off == pkt->len) {
300 			virtio_transport_dec_rx_pkt(vvs, pkt);
301 			list_del(&pkt->list);
302 			virtio_transport_free_pkt(pkt);
303 		}
304 	}
305 
306 	free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);
307 
308 	spin_unlock_bh(&vvs->rx_lock);
309 
310 	/* To reduce the number of credit update messages,
311 	 * don't update credits as long as lots of space is available.
312 	 * Note: the limit chosen here is arbitrary. Setting the limit
313 	 * too high causes extra messages. Too low causes transmitter
314 	 * stalls. As stalls are in theory more expensive than extra
315 	 * messages, we set the limit to a high value. TODO: experiment
316 	 * with different values.
317 	 */
318 	if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
319 		virtio_transport_send_credit_update(vsk,
320 						    VIRTIO_VSOCK_TYPE_STREAM,
321 						    NULL);
322 	}
323 
324 	return total;
325 
326 out:
327 	if (total)
328 		err = total;
329 	return err;
330 }
331 
332 ssize_t
333 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
334 				struct msghdr *msg,
335 				size_t len, int flags)
336 {
337 	if (flags & MSG_PEEK)
338 		return -EOPNOTSUPP;
339 
340 	return virtio_transport_stream_do_dequeue(vsk, msg, len);
341 }
342 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
343 
344 int
345 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
346 			       struct msghdr *msg,
347 			       size_t len, int flags)
348 {
349 	return -EOPNOTSUPP;
350 }
351 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
352 
353 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
354 {
355 	struct virtio_vsock_sock *vvs = vsk->trans;
356 	s64 bytes;
357 
358 	spin_lock_bh(&vvs->rx_lock);
359 	bytes = vvs->rx_bytes;
360 	spin_unlock_bh(&vvs->rx_lock);
361 
362 	return bytes;
363 }
364 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
365 
366 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
367 {
368 	struct virtio_vsock_sock *vvs = vsk->trans;
369 	s64 bytes;
370 
371 	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
372 	if (bytes < 0)
373 		bytes = 0;
374 
375 	return bytes;
376 }
377 
378 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
379 {
380 	struct virtio_vsock_sock *vvs = vsk->trans;
381 	s64 bytes;
382 
383 	spin_lock_bh(&vvs->tx_lock);
384 	bytes = virtio_transport_has_space(vsk);
385 	spin_unlock_bh(&vvs->tx_lock);
386 
387 	return bytes;
388 }
389 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
390 
391 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
392 				    struct vsock_sock *psk)
393 {
394 	struct virtio_vsock_sock *vvs;
395 
396 	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
397 	if (!vvs)
398 		return -ENOMEM;
399 
400 	vsk->trans = vvs;
401 	vvs->vsk = vsk;
402 	if (psk) {
403 		struct virtio_vsock_sock *ptrans = psk->trans;
404 
405 		vvs->buf_size	= ptrans->buf_size;
406 		vvs->buf_size_min = ptrans->buf_size_min;
407 		vvs->buf_size_max = ptrans->buf_size_max;
408 		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
409 	} else {
410 		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
411 		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
412 		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
413 	}
414 
415 	vvs->buf_alloc = vvs->buf_size;
416 
417 	spin_lock_init(&vvs->rx_lock);
418 	spin_lock_init(&vvs->tx_lock);
419 	INIT_LIST_HEAD(&vvs->rx_queue);
420 
421 	return 0;
422 }
423 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
424 
425 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
426 {
427 	struct virtio_vsock_sock *vvs = vsk->trans;
428 
429 	return vvs->buf_size;
430 }
431 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
432 
433 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
434 {
435 	struct virtio_vsock_sock *vvs = vsk->trans;
436 
437 	return vvs->buf_size_min;
438 }
439 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
440 
441 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
442 {
443 	struct virtio_vsock_sock *vvs = vsk->trans;
444 
445 	return vvs->buf_size_max;
446 }
447 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
448 
449 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
450 {
451 	struct virtio_vsock_sock *vvs = vsk->trans;
452 
453 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
454 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
455 	if (val < vvs->buf_size_min)
456 		vvs->buf_size_min = val;
457 	if (val > vvs->buf_size_max)
458 		vvs->buf_size_max = val;
459 	vvs->buf_size = val;
460 	vvs->buf_alloc = val;
461 }
462 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
463 
464 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
465 {
466 	struct virtio_vsock_sock *vvs = vsk->trans;
467 
468 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
469 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
470 	if (val > vvs->buf_size)
471 		vvs->buf_size = val;
472 	vvs->buf_size_min = val;
473 }
474 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
475 
476 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
477 {
478 	struct virtio_vsock_sock *vvs = vsk->trans;
479 
480 	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
481 		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
482 	if (val < vvs->buf_size)
483 		vvs->buf_size = val;
484 	vvs->buf_size_max = val;
485 }
486 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
487 
488 int
489 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
490 				size_t target,
491 				bool *data_ready_now)
492 {
493 	if (vsock_stream_has_data(vsk))
494 		*data_ready_now = true;
495 	else
496 		*data_ready_now = false;
497 
498 	return 0;
499 }
500 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
501 
502 int
503 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
504 				 size_t target,
505 				 bool *space_avail_now)
506 {
507 	s64 free_space;
508 
509 	free_space = vsock_stream_has_space(vsk);
510 	if (free_space > 0)
511 		*space_avail_now = true;
512 	else if (free_space == 0)
513 		*space_avail_now = false;
514 
515 	return 0;
516 }
517 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
518 
519 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
520 	size_t target, struct vsock_transport_recv_notify_data *data)
521 {
522 	return 0;
523 }
524 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
525 
526 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
527 	size_t target, struct vsock_transport_recv_notify_data *data)
528 {
529 	return 0;
530 }
531 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
532 
533 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
534 	size_t target, struct vsock_transport_recv_notify_data *data)
535 {
536 	return 0;
537 }
538 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
539 
540 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
541 	size_t target, ssize_t copied, bool data_read,
542 	struct vsock_transport_recv_notify_data *data)
543 {
544 	return 0;
545 }
546 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
547 
548 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
549 	struct vsock_transport_send_notify_data *data)
550 {
551 	return 0;
552 }
553 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
554 
555 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
556 	struct vsock_transport_send_notify_data *data)
557 {
558 	return 0;
559 }
560 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
561 
562 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
563 	struct vsock_transport_send_notify_data *data)
564 {
565 	return 0;
566 }
567 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
568 
569 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
570 	ssize_t written, struct vsock_transport_send_notify_data *data)
571 {
572 	return 0;
573 }
574 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
575 
576 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
577 {
578 	struct virtio_vsock_sock *vvs = vsk->trans;
579 
580 	return vvs->buf_size;
581 }
582 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
583 
584 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
585 {
586 	return true;
587 }
588 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
589 
590 bool virtio_transport_stream_allow(u32 cid, u32 port)
591 {
592 	return true;
593 }
594 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
595 
596 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
597 				struct sockaddr_vm *addr)
598 {
599 	return -EOPNOTSUPP;
600 }
601 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
602 
603 bool virtio_transport_dgram_allow(u32 cid, u32 port)
604 {
605 	return false;
606 }
607 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
608 
609 int virtio_transport_connect(struct vsock_sock *vsk)
610 {
611 	struct virtio_vsock_pkt_info info = {
612 		.op = VIRTIO_VSOCK_OP_REQUEST,
613 		.type = VIRTIO_VSOCK_TYPE_STREAM,
614 		.vsk = vsk,
615 	};
616 
617 	return virtio_transport_send_pkt_info(vsk, &info);
618 }
619 EXPORT_SYMBOL_GPL(virtio_transport_connect);
620 
621 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
622 {
623 	struct virtio_vsock_pkt_info info = {
624 		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
625 		.type = VIRTIO_VSOCK_TYPE_STREAM,
626 		.flags = (mode & RCV_SHUTDOWN ?
627 			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
628 			 (mode & SEND_SHUTDOWN ?
629 			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
630 		.vsk = vsk,
631 	};
632 
633 	return virtio_transport_send_pkt_info(vsk, &info);
634 }
635 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
636 
637 int
638 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
639 			       struct sockaddr_vm *remote_addr,
640 			       struct msghdr *msg,
641 			       size_t dgram_len)
642 {
643 	return -EOPNOTSUPP;
644 }
645 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
646 
647 ssize_t
648 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
649 				struct msghdr *msg,
650 				size_t len)
651 {
652 	struct virtio_vsock_pkt_info info = {
653 		.op = VIRTIO_VSOCK_OP_RW,
654 		.type = VIRTIO_VSOCK_TYPE_STREAM,
655 		.msg = msg,
656 		.pkt_len = len,
657 		.vsk = vsk,
658 	};
659 
660 	return virtio_transport_send_pkt_info(vsk, &info);
661 }
662 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
663 
664 void virtio_transport_destruct(struct vsock_sock *vsk)
665 {
666 	struct virtio_vsock_sock *vvs = vsk->trans;
667 
668 	kfree(vvs);
669 }
670 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
671 
672 static int virtio_transport_reset(struct vsock_sock *vsk,
673 				  struct virtio_vsock_pkt *pkt)
674 {
675 	struct virtio_vsock_pkt_info info = {
676 		.op = VIRTIO_VSOCK_OP_RST,
677 		.type = VIRTIO_VSOCK_TYPE_STREAM,
678 		.reply = !!pkt,
679 		.vsk = vsk,
680 	};
681 
682 	/* Send RST only if the original pkt is not a RST pkt */
683 	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
684 		return 0;
685 
686 	return virtio_transport_send_pkt_info(vsk, &info);
687 }
688 
689 /* Normally packets are associated with a socket.  There may be no socket if an
690  * attempt was made to connect to a socket that does not exist.
691  */
692 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
693 {
694 	const struct virtio_transport *t;
695 	struct virtio_vsock_pkt *reply;
696 	struct virtio_vsock_pkt_info info = {
697 		.op = VIRTIO_VSOCK_OP_RST,
698 		.type = le16_to_cpu(pkt->hdr.type),
699 		.reply = true,
700 	};
701 
702 	/* Send RST only if the original pkt is not a RST pkt */
703 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
704 		return 0;
705 
706 	reply = virtio_transport_alloc_pkt(&info, 0,
707 					   le64_to_cpu(pkt->hdr.dst_cid),
708 					   le32_to_cpu(pkt->hdr.dst_port),
709 					   le64_to_cpu(pkt->hdr.src_cid),
710 					   le32_to_cpu(pkt->hdr.src_port));
711 	if (!reply)
712 		return -ENOMEM;
713 
714 	t = virtio_transport_get_ops();
715 	if (!t) {
716 		virtio_transport_free_pkt(reply);
717 		return -ENOTCONN;
718 	}
719 
720 	return t->send_pkt(reply);
721 }
722 
723 static void virtio_transport_wait_close(struct sock *sk, long timeout)
724 {
725 	if (timeout) {
726 		DEFINE_WAIT_FUNC(wait, woken_wake_function);
727 
728 		add_wait_queue(sk_sleep(sk), &wait);
729 
730 		do {
731 			if (sk_wait_event(sk, &timeout,
732 					  sock_flag(sk, SOCK_DONE), &wait))
733 				break;
734 		} while (!signal_pending(current) && timeout);
735 
736 		remove_wait_queue(sk_sleep(sk), &wait);
737 	}
738 }
739 
740 static void virtio_transport_do_close(struct vsock_sock *vsk,
741 				      bool cancel_timeout)
742 {
743 	struct sock *sk = sk_vsock(vsk);
744 
745 	sock_set_flag(sk, SOCK_DONE);
746 	vsk->peer_shutdown = SHUTDOWN_MASK;
747 	if (vsock_stream_has_data(vsk) <= 0)
748 		sk->sk_state = TCP_CLOSING;
749 	sk->sk_state_change(sk);
750 
751 	if (vsk->close_work_scheduled &&
752 	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
753 		vsk->close_work_scheduled = false;
754 
755 		vsock_remove_sock(vsk);
756 
757 		/* Release refcnt obtained when we scheduled the timeout */
758 		sock_put(sk);
759 	}
760 }
761 
762 static void virtio_transport_close_timeout(struct work_struct *work)
763 {
764 	struct vsock_sock *vsk =
765 		container_of(work, struct vsock_sock, close_work.work);
766 	struct sock *sk = sk_vsock(vsk);
767 
768 	sock_hold(sk);
769 	lock_sock(sk);
770 
771 	if (!sock_flag(sk, SOCK_DONE)) {
772 		(void)virtio_transport_reset(vsk, NULL);
773 
774 		virtio_transport_do_close(vsk, false);
775 	}
776 
777 	vsk->close_work_scheduled = false;
778 
779 	release_sock(sk);
780 	sock_put(sk);
781 }
782 
783 /* User context, vsk->sk is locked */
784 static bool virtio_transport_close(struct vsock_sock *vsk)
785 {
786 	struct sock *sk = &vsk->sk;
787 
788 	if (!(sk->sk_state == TCP_ESTABLISHED ||
789 	      sk->sk_state == TCP_CLOSING))
790 		return true;
791 
792 	/* Already received SHUTDOWN from peer, reply with RST */
793 	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
794 		(void)virtio_transport_reset(vsk, NULL);
795 		return true;
796 	}
797 
798 	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
799 		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
800 
801 	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
802 		virtio_transport_wait_close(sk, sk->sk_lingertime);
803 
804 	if (sock_flag(sk, SOCK_DONE)) {
805 		return true;
806 	}
807 
808 	sock_hold(sk);
809 	INIT_DELAYED_WORK(&vsk->close_work,
810 			  virtio_transport_close_timeout);
811 	vsk->close_work_scheduled = true;
812 	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
813 	return false;
814 }
815 
816 void virtio_transport_release(struct vsock_sock *vsk)
817 {
818 	struct virtio_vsock_sock *vvs = vsk->trans;
819 	struct virtio_vsock_pkt *pkt, *tmp;
820 	struct sock *sk = &vsk->sk;
821 	bool remove_sock = true;
822 
823 	lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
824 	if (sk->sk_type == SOCK_STREAM)
825 		remove_sock = virtio_transport_close(vsk);
826 
827 	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
828 		list_del(&pkt->list);
829 		virtio_transport_free_pkt(pkt);
830 	}
831 	release_sock(sk);
832 
833 	if (remove_sock)
834 		vsock_remove_sock(vsk);
835 }
836 EXPORT_SYMBOL_GPL(virtio_transport_release);
837 
838 static int
839 virtio_transport_recv_connecting(struct sock *sk,
840 				 struct virtio_vsock_pkt *pkt)
841 {
842 	struct vsock_sock *vsk = vsock_sk(sk);
843 	int err;
844 	int skerr;
845 
846 	switch (le16_to_cpu(pkt->hdr.op)) {
847 	case VIRTIO_VSOCK_OP_RESPONSE:
848 		sk->sk_state = TCP_ESTABLISHED;
849 		sk->sk_socket->state = SS_CONNECTED;
850 		vsock_insert_connected(vsk);
851 		sk->sk_state_change(sk);
852 		break;
853 	case VIRTIO_VSOCK_OP_INVALID:
854 		break;
855 	case VIRTIO_VSOCK_OP_RST:
856 		skerr = ECONNRESET;
857 		err = 0;
858 		goto destroy;
859 	default:
860 		skerr = EPROTO;
861 		err = -EINVAL;
862 		goto destroy;
863 	}
864 	return 0;
865 
866 destroy:
867 	virtio_transport_reset(vsk, pkt);
868 	sk->sk_state = TCP_CLOSE;
869 	sk->sk_err = skerr;
870 	sk->sk_error_report(sk);
871 	return err;
872 }
873 
874 static void
875 virtio_transport_recv_enqueue(struct vsock_sock *vsk,
876 			      struct virtio_vsock_pkt *pkt)
877 {
878 	struct virtio_vsock_sock *vvs = vsk->trans;
879 	bool free_pkt = false;
880 
881 	pkt->len = le32_to_cpu(pkt->hdr.len);
882 	pkt->off = 0;
883 
884 	spin_lock_bh(&vvs->rx_lock);
885 
886 	virtio_transport_inc_rx_pkt(vvs, pkt);
887 
888 	/* Try to copy small packets into the buffer of last packet queued,
889 	 * to avoid wasting memory queueing the entire buffer with a small
890 	 * payload.
891 	 */
892 	if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
893 		struct virtio_vsock_pkt *last_pkt;
894 
895 		last_pkt = list_last_entry(&vvs->rx_queue,
896 					   struct virtio_vsock_pkt, list);
897 
898 		/* If there is space in the last packet queued, we copy the
899 		 * new packet in its buffer.
900 		 */
901 		if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
902 			memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
903 			       pkt->len);
904 			last_pkt->len += pkt->len;
905 			free_pkt = true;
906 			goto out;
907 		}
908 	}
909 
910 	list_add_tail(&pkt->list, &vvs->rx_queue);
911 
912 out:
913 	spin_unlock_bh(&vvs->rx_lock);
914 	if (free_pkt)
915 		virtio_transport_free_pkt(pkt);
916 }
917 
918 static int
919 virtio_transport_recv_connected(struct sock *sk,
920 				struct virtio_vsock_pkt *pkt)
921 {
922 	struct vsock_sock *vsk = vsock_sk(sk);
923 	int err = 0;
924 
925 	switch (le16_to_cpu(pkt->hdr.op)) {
926 	case VIRTIO_VSOCK_OP_RW:
927 		virtio_transport_recv_enqueue(vsk, pkt);
928 		sk->sk_data_ready(sk);
929 		return err;
930 	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
931 		sk->sk_write_space(sk);
932 		break;
933 	case VIRTIO_VSOCK_OP_SHUTDOWN:
934 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
935 			vsk->peer_shutdown |= RCV_SHUTDOWN;
936 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
937 			vsk->peer_shutdown |= SEND_SHUTDOWN;
938 		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
939 		    vsock_stream_has_data(vsk) <= 0) {
940 			sock_set_flag(sk, SOCK_DONE);
941 			sk->sk_state = TCP_CLOSING;
942 		}
943 		if (le32_to_cpu(pkt->hdr.flags))
944 			sk->sk_state_change(sk);
945 		break;
946 	case VIRTIO_VSOCK_OP_RST:
947 		virtio_transport_do_close(vsk, true);
948 		break;
949 	default:
950 		err = -EINVAL;
951 		break;
952 	}
953 
954 	virtio_transport_free_pkt(pkt);
955 	return err;
956 }
957 
958 static void
959 virtio_transport_recv_disconnecting(struct sock *sk,
960 				    struct virtio_vsock_pkt *pkt)
961 {
962 	struct vsock_sock *vsk = vsock_sk(sk);
963 
964 	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
965 		virtio_transport_do_close(vsk, true);
966 }
967 
968 static int
969 virtio_transport_send_response(struct vsock_sock *vsk,
970 			       struct virtio_vsock_pkt *pkt)
971 {
972 	struct virtio_vsock_pkt_info info = {
973 		.op = VIRTIO_VSOCK_OP_RESPONSE,
974 		.type = VIRTIO_VSOCK_TYPE_STREAM,
975 		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
976 		.remote_port = le32_to_cpu(pkt->hdr.src_port),
977 		.reply = true,
978 		.vsk = vsk,
979 	};
980 
981 	return virtio_transport_send_pkt_info(vsk, &info);
982 }
983 
984 /* Handle server socket */
985 static int
986 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
987 {
988 	struct vsock_sock *vsk = vsock_sk(sk);
989 	struct vsock_sock *vchild;
990 	struct sock *child;
991 
992 	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
993 		virtio_transport_reset(vsk, pkt);
994 		return -EINVAL;
995 	}
996 
997 	if (sk_acceptq_is_full(sk)) {
998 		virtio_transport_reset(vsk, pkt);
999 		return -ENOMEM;
1000 	}
1001 
1002 	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
1003 			       sk->sk_type, 0);
1004 	if (!child) {
1005 		virtio_transport_reset(vsk, pkt);
1006 		return -ENOMEM;
1007 	}
1008 
1009 	sk->sk_ack_backlog++;
1010 
1011 	lock_sock_nested(child, SINGLE_DEPTH_NESTING);
1012 
1013 	child->sk_state = TCP_ESTABLISHED;
1014 
1015 	vchild = vsock_sk(child);
1016 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
1017 			le32_to_cpu(pkt->hdr.dst_port));
1018 	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
1019 			le32_to_cpu(pkt->hdr.src_port));
1020 
1021 	vsock_insert_connected(vchild);
1022 	vsock_enqueue_accept(sk, child);
1023 	virtio_transport_send_response(vchild, pkt);
1024 
1025 	release_sock(child);
1026 
1027 	sk->sk_data_ready(sk);
1028 	return 0;
1029 }
1030 
1031 static bool virtio_transport_space_update(struct sock *sk,
1032 					  struct virtio_vsock_pkt *pkt)
1033 {
1034 	struct vsock_sock *vsk = vsock_sk(sk);
1035 	struct virtio_vsock_sock *vvs = vsk->trans;
1036 	bool space_available;
1037 
1038 	/* buf_alloc and fwd_cnt is always included in the hdr */
1039 	spin_lock_bh(&vvs->tx_lock);
1040 	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
1041 	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
1042 	space_available = virtio_transport_has_space(vsk);
1043 	spin_unlock_bh(&vvs->tx_lock);
1044 	return space_available;
1045 }
1046 
1047 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
1048  * lock.
1049  */
1050 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
1051 {
1052 	struct sockaddr_vm src, dst;
1053 	struct vsock_sock *vsk;
1054 	struct sock *sk;
1055 	bool space_available;
1056 
1057 	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
1058 			le32_to_cpu(pkt->hdr.src_port));
1059 	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
1060 			le32_to_cpu(pkt->hdr.dst_port));
1061 
1062 	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
1063 					dst.svm_cid, dst.svm_port,
1064 					le32_to_cpu(pkt->hdr.len),
1065 					le16_to_cpu(pkt->hdr.type),
1066 					le16_to_cpu(pkt->hdr.op),
1067 					le32_to_cpu(pkt->hdr.flags),
1068 					le32_to_cpu(pkt->hdr.buf_alloc),
1069 					le32_to_cpu(pkt->hdr.fwd_cnt));
1070 
1071 	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
1072 		(void)virtio_transport_reset_no_sock(pkt);
1073 		goto free_pkt;
1074 	}
1075 
1076 	/* The socket must be in connected or bound table
1077 	 * otherwise send reset back
1078 	 */
1079 	sk = vsock_find_connected_socket(&src, &dst);
1080 	if (!sk) {
1081 		sk = vsock_find_bound_socket(&dst);
1082 		if (!sk) {
1083 			(void)virtio_transport_reset_no_sock(pkt);
1084 			goto free_pkt;
1085 		}
1086 	}
1087 
1088 	vsk = vsock_sk(sk);
1089 
1090 	space_available = virtio_transport_space_update(sk, pkt);
1091 
1092 	lock_sock(sk);
1093 
1094 	/* Update CID in case it has changed after a transport reset event */
1095 	vsk->local_addr.svm_cid = dst.svm_cid;
1096 
1097 	if (space_available)
1098 		sk->sk_write_space(sk);
1099 
1100 	switch (sk->sk_state) {
1101 	case TCP_LISTEN:
1102 		virtio_transport_recv_listen(sk, pkt);
1103 		virtio_transport_free_pkt(pkt);
1104 		break;
1105 	case TCP_SYN_SENT:
1106 		virtio_transport_recv_connecting(sk, pkt);
1107 		virtio_transport_free_pkt(pkt);
1108 		break;
1109 	case TCP_ESTABLISHED:
1110 		virtio_transport_recv_connected(sk, pkt);
1111 		break;
1112 	case TCP_CLOSING:
1113 		virtio_transport_recv_disconnecting(sk, pkt);
1114 		virtio_transport_free_pkt(pkt);
1115 		break;
1116 	default:
1117 		virtio_transport_free_pkt(pkt);
1118 		break;
1119 	}
1120 	release_sock(sk);
1121 
1122 	/* Release refcnt obtained when we fetched this socket out of the
1123 	 * bound or connected list.
1124 	 */
1125 	sock_put(sk);
1126 	return;
1127 
1128 free_pkt:
1129 	virtio_transport_free_pkt(pkt);
1130 }
1131 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1132 
1133 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1134 {
1135 	kfree(pkt->buf);
1136 	kfree(pkt);
1137 }
1138 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1139 
1140 MODULE_LICENSE("GPL v2");
1141 MODULE_AUTHOR("Asias He");
1142 MODULE_DESCRIPTION("common code for virtio vsock");
1143