xref: /openbmc/linux/drivers/vhost/net.c (revision 2eb0f624b709e78ec8e2f4c3412947703db99301)
1 /* Copyright (C) 2009 Red Hat, Inc.
2  * Author: Michael S. Tsirkin <mst@redhat.com>
3  *
4  * This work is licensed under the terms of the GNU GPL, version 2.
5  *
6  * virtio-net server in host kernel.
7  */
8 
9 #include <linux/compat.h>
10 #include <linux/eventfd.h>
11 #include <linux/vhost.h>
12 #include <linux/virtio_net.h>
13 #include <linux/miscdevice.h>
14 #include <linux/module.h>
15 #include <linux/moduleparam.h>
16 #include <linux/mutex.h>
17 #include <linux/workqueue.h>
18 #include <linux/file.h>
19 #include <linux/slab.h>
20 #include <linux/sched/clock.h>
21 #include <linux/sched/signal.h>
22 #include <linux/vmalloc.h>
23 
24 #include <linux/net.h>
25 #include <linux/if_packet.h>
26 #include <linux/if_arp.h>
27 #include <linux/if_tun.h>
28 #include <linux/if_macvlan.h>
29 #include <linux/if_tap.h>
30 #include <linux/if_vlan.h>
31 #include <linux/skb_array.h>
32 #include <linux/skbuff.h>
33 
34 #include <net/sock.h>
35 #include <net/xdp.h>
36 
37 #include "vhost.h"
38 
39 static int experimental_zcopytx = 1;
40 module_param(experimental_zcopytx, int, 0444);
41 MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
42 		                       " 1 - Enable; 0 - Disable");
43 
44 /* Max number of bytes transferred before requeueing the job.
45  * Using this limit prevents one virtqueue from starving others. */
46 #define VHOST_NET_WEIGHT 0x80000
47 
48 /* Max number of packets transferred before requeueing the job.
49  * Using this limit prevents one virtqueue from starving rx. */
50 #define VHOST_NET_PKT_WEIGHT(vq) ((vq)->num * 2)
51 
52 /* MAX number of TX used buffers for outstanding zerocopy */
53 #define VHOST_MAX_PEND 128
54 #define VHOST_GOODCOPY_LEN 256
55 
56 /*
57  * For transmit, used buffer len is unused; we override it to track buffer
58  * status internally; used for zerocopy tx only.
59  */
60 /* Lower device DMA failed */
61 #define VHOST_DMA_FAILED_LEN	((__force __virtio32)3)
62 /* Lower device DMA done */
63 #define VHOST_DMA_DONE_LEN	((__force __virtio32)2)
64 /* Lower device DMA in progress */
65 #define VHOST_DMA_IN_PROGRESS	((__force __virtio32)1)
66 /* Buffer unused */
67 #define VHOST_DMA_CLEAR_LEN	((__force __virtio32)0)
68 
69 #define VHOST_DMA_IS_DONE(len) ((__force u32)(len) >= (__force u32)VHOST_DMA_DONE_LEN)
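/*
 * Illustrative sketch (not part of the driver): the four "len" values above
 * form a small state machine for a zerocopy TX buffer.  VHOST_DMA_IS_DONE()
 * treats both DONE (2) and FAILED (3) as completed, while IN_PROGRESS (1)
 * and CLEAR (0) are not yet completed.  A plain-C rendition of the same
 * predicate (names here are hypothetical):
 */
#include <assert.h>

static void vhost_dma_state_example(void)
{
	enum { CLEAR = 0, IN_PROGRESS = 1, DONE = 2, FAILED = 3 };

	/* same check as VHOST_DMA_IS_DONE(): len >= DONE */
	assert(!(CLEAR >= DONE));		/* unused buffer: not done */
	assert(!(IN_PROGRESS >= DONE));		/* DMA still in flight */
	assert(DONE >= DONE);			/* completed successfully */
	assert(FAILED >= DONE);			/* completed, but failed */
}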
70 
71 enum {
72 	VHOST_NET_FEATURES = VHOST_FEATURES |
73 			 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
74 			 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
75 			 (1ULL << VIRTIO_F_IOMMU_PLATFORM)
76 };
77 
78 enum {
79 	VHOST_NET_VQ_RX = 0,
80 	VHOST_NET_VQ_TX = 1,
81 	VHOST_NET_VQ_MAX = 2,
82 };
83 
84 struct vhost_net_ubuf_ref {
85 	/* refcount follows semantics similar to kref:
86 	 *  0: object is released
87 	 *  1: no outstanding ubufs
88 	 * >1: outstanding ubufs
89 	 */
90 	atomic_t refcount;
91 	wait_queue_head_t wait;
92 	struct vhost_virtqueue *vq;
93 };
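/*
 * Illustrative sketch (not part of the driver): the refcount convention above
 * mirrors kref but is offset by one, so "1" means "no ubufs in flight" rather
 * than "free me".  A minimal userspace model of the same lifecycle using C11
 * atomics in place of atomic_t (all names hypothetical):
 */
#include <stdatomic.h>
#include <assert.h>

struct ubuf_ref_model {
	atomic_int refcount;	/* 1 = idle, >1 = outstanding ubufs, 0 = released */
};

static void ubuf_ref_example(void)
{
	struct ubuf_ref_model r;

	atomic_init(&r.refcount, 1);		/* allocated, no ubufs outstanding */
	atomic_fetch_add(&r.refcount, 1);	/* one zerocopy ubuf submitted */
	assert(atomic_load(&r.refcount) == 2);	/* outstanding */
	atomic_fetch_sub(&r.refcount, 1);	/* completion callback puts it */
	assert(atomic_load(&r.refcount) == 1);	/* idle again */
	atomic_fetch_sub(&r.refcount, 1);	/* final put on teardown */
	assert(atomic_load(&r.refcount) == 0);	/* released; waiters wake up */
}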
94 
95 #define VHOST_RX_BATCH 64
96 struct vhost_net_buf {
97 	void **queue;
98 	int tail;
99 	int head;
100 };
101 
102 struct vhost_net_virtqueue {
103 	struct vhost_virtqueue vq;
104 	size_t vhost_hlen;
105 	size_t sock_hlen;
106 	/* vhost zerocopy support fields below: */
107 	/* last used idx for outstanding DMA zerocopy buffers */
108 	int upend_idx;
109 	/* first used idx for DMA done zerocopy buffers */
110 	int done_idx;
111 	/* an array of userspace buffers info */
112 	struct ubuf_info *ubuf_info;
113 	/* Reference counting for outstanding ubufs.
114 	 * Protected by vq mutex. Writers must also take device mutex. */
115 	struct vhost_net_ubuf_ref *ubufs;
116 	struct ptr_ring *rx_ring;
117 	struct vhost_net_buf rxq;
118 };
119 
120 struct vhost_net {
121 	struct vhost_dev dev;
122 	struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
123 	struct vhost_poll poll[VHOST_NET_VQ_MAX];
124 	/* Number of TX recently submitted.
125 	 * Protected by tx vq lock. */
126 	unsigned tx_packets;
127 	/* Number of times zerocopy TX recently failed.
128 	 * Protected by tx vq lock. */
129 	unsigned tx_zcopy_err;
130 	/* Flush in progress. Protected by tx vq lock. */
131 	bool tx_flush;
132 };
133 
134 static unsigned vhost_net_zcopy_mask __read_mostly;
135 
136 static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
137 {
138 	if (rxq->tail != rxq->head)
139 		return rxq->queue[rxq->head];
140 	else
141 		return NULL;
142 }
143 
144 static int vhost_net_buf_get_size(struct vhost_net_buf *rxq)
145 {
146 	return rxq->tail - rxq->head;
147 }
148 
149 static int vhost_net_buf_is_empty(struct vhost_net_buf *rxq)
150 {
151 	return rxq->tail == rxq->head;
152 }
153 
154 static void *vhost_net_buf_consume(struct vhost_net_buf *rxq)
155 {
156 	void *ret = vhost_net_buf_get_ptr(rxq);
157 	++rxq->head;
158 	return ret;
159 }
160 
161 static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq)
162 {
163 	struct vhost_net_buf *rxq = &nvq->rxq;
164 
165 	rxq->head = 0;
166 	rxq->tail = ptr_ring_consume_batched(nvq->rx_ring, rxq->queue,
167 					      VHOST_RX_BATCH);
168 	return rxq->tail;
169 }
170 
171 static void vhost_net_buf_unproduce(struct vhost_net_virtqueue *nvq)
172 {
173 	struct vhost_net_buf *rxq = &nvq->rxq;
174 
175 	if (nvq->rx_ring && !vhost_net_buf_is_empty(rxq)) {
176 		ptr_ring_unconsume(nvq->rx_ring, rxq->queue + rxq->head,
177 				   vhost_net_buf_get_size(rxq),
178 				   tun_ptr_free);
179 		rxq->head = rxq->tail = 0;
180 	}
181 }
182 
183 static int vhost_net_buf_peek_len(void *ptr)
184 {
185 	if (tun_is_xdp_frame(ptr)) {
186 		struct xdp_frame *xdpf = tun_ptr_to_xdp(ptr);
187 
188 		return xdpf->len;
189 	}
190 
191 	return __skb_array_len_with_tag(ptr);
192 }
193 
194 static int vhost_net_buf_peek(struct vhost_net_virtqueue *nvq)
195 {
196 	struct vhost_net_buf *rxq = &nvq->rxq;
197 
198 	if (!vhost_net_buf_is_empty(rxq))
199 		goto out;
200 
201 	if (!vhost_net_buf_produce(nvq))
202 		return 0;
203 
204 out:
205 	return vhost_net_buf_peek_len(vhost_net_buf_get_ptr(rxq));
206 }
207 
208 static void vhost_net_buf_init(struct vhost_net_buf *rxq)
209 {
210 	rxq->head = rxq->tail = 0;
211 }
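/*
 * Illustrative sketch (not part of the driver): the rxq head/tail pair above
 * implements a "refill in batches, drain one at a time" cache in front of the
 * tap ptr_ring.  A self-contained model, with a plain source array standing
 * in for ptr_ring_consume_batched() (all names hypothetical):
 */
#include <stddef.h>

struct buf_cache {
	void *queue[64];	/* like the VHOST_RX_BATCH-sized rxq->queue */
	int tail;		/* one past the last valid cached entry */
	int head;		/* next entry to hand out */
};

/* refill the cache from a backing source; returns number of entries cached */
static int cache_produce(struct buf_cache *c, void **src, int n)
{
	c->head = 0;
	c->tail = n < 64 ? n : 64;
	for (int i = 0; i < c->tail; i++)
		c->queue[i] = src[i];
	return c->tail;
}

/* hand out one cached entry, or NULL when the cache is empty */
static void *cache_consume(struct buf_cache *c)
{
	return c->head != c->tail ? c->queue[c->head++] : NULL;
}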
212 
213 static void vhost_net_enable_zcopy(int vq)
214 {
215 	vhost_net_zcopy_mask |= 0x1 << vq;
216 }
217 
218 static struct vhost_net_ubuf_ref *
219 vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
220 {
221 	struct vhost_net_ubuf_ref *ubufs;
222 	/* No zero copy backend? Nothing to count. */
223 	if (!zcopy)
224 		return NULL;
225 	ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
226 	if (!ubufs)
227 		return ERR_PTR(-ENOMEM);
228 	atomic_set(&ubufs->refcount, 1);
229 	init_waitqueue_head(&ubufs->wait);
230 	ubufs->vq = vq;
231 	return ubufs;
232 }
233 
234 static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
235 {
236 	int r = atomic_sub_return(1, &ubufs->refcount);
237 	if (unlikely(!r))
238 		wake_up(&ubufs->wait);
239 	return r;
240 }
241 
242 static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
243 {
244 	vhost_net_ubuf_put(ubufs);
245 	wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
246 }
247 
248 static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
249 {
250 	vhost_net_ubuf_put_and_wait(ubufs);
251 	kfree(ubufs);
252 }
253 
254 static void vhost_net_clear_ubuf_info(struct vhost_net *n)
255 {
256 	int i;
257 
258 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
259 		kfree(n->vqs[i].ubuf_info);
260 		n->vqs[i].ubuf_info = NULL;
261 	}
262 }
263 
264 static int vhost_net_set_ubuf_info(struct vhost_net *n)
265 {
266 	bool zcopy;
267 	int i;
268 
269 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
270 		zcopy = vhost_net_zcopy_mask & (0x1 << i);
271 		if (!zcopy)
272 			continue;
273 		n->vqs[i].ubuf_info = kmalloc(sizeof(*n->vqs[i].ubuf_info) *
274 					      UIO_MAXIOV, GFP_KERNEL);
275 		if (!n->vqs[i].ubuf_info)
276 			goto err;
277 	}
278 	return 0;
279 
280 err:
281 	vhost_net_clear_ubuf_info(n);
282 	return -ENOMEM;
283 }
284 
285 static void vhost_net_vq_reset(struct vhost_net *n)
286 {
287 	int i;
288 
289 	vhost_net_clear_ubuf_info(n);
290 
291 	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
292 		n->vqs[i].done_idx = 0;
293 		n->vqs[i].upend_idx = 0;
294 		n->vqs[i].ubufs = NULL;
295 		n->vqs[i].vhost_hlen = 0;
296 		n->vqs[i].sock_hlen = 0;
297 		vhost_net_buf_init(&n->vqs[i].rxq);
298 	}
299 
300 }
301 
302 static void vhost_net_tx_packet(struct vhost_net *net)
303 {
304 	++net->tx_packets;
305 	if (net->tx_packets < 1024)
306 		return;
307 	net->tx_packets = 0;
308 	net->tx_zcopy_err = 0;
309 }
310 
311 static void vhost_net_tx_err(struct vhost_net *net)
312 {
313 	++net->tx_zcopy_err;
314 }
315 
316 static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
317 {
318 	/* TX flush waits for outstanding DMAs to be done.
319 	 * Don't start new DMAs.
320 	 */
321 	return !net->tx_flush &&
322 		net->tx_packets / 64 >= net->tx_zcopy_err;
323 }
324 
325 static bool vhost_sock_zcopy(struct socket *sock)
326 {
327 	return unlikely(experimental_zcopytx) &&
328 		sock_flag(sock->sk, SOCK_ZEROCOPY);
329 }
330 
331 /* The lower device driver may complete DMAs out of order for some reason.
332  * upend_idx tracks the end of the range of used idx entries with outstanding
333  * DMA, and done_idx tracks its head. Once the lower device has completed
334  * DMAs contiguously, we signal the guest's used idx.
335  */
336 static void vhost_zerocopy_signal_used(struct vhost_net *net,
337 				       struct vhost_virtqueue *vq)
338 {
339 	struct vhost_net_virtqueue *nvq =
340 		container_of(vq, struct vhost_net_virtqueue, vq);
341 	int i, add;
342 	int j = 0;
343 
344 	for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
345 		if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
346 			vhost_net_tx_err(net);
347 		if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
348 			vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
349 			++j;
350 		} else
351 			break;
352 	}
353 	while (j) {
354 		add = min(UIO_MAXIOV - nvq->done_idx, j);
355 		vhost_add_used_and_signal_n(vq->dev, vq,
356 					    &vq->heads[nvq->done_idx], add);
357 		nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
358 		j -= add;
359 	}
360 }
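/*
 * Illustrative sketch (not part of the driver): done_idx..upend_idx is a
 * window of submitted zerocopy buffers in a UIO_MAXIOV-sized circular array.
 * Completions may arrive out of order, so only the contiguous prefix starting
 * at done_idx can be reported to the guest.  A simplified model of that walk
 * (names hypothetical):
 */
static int count_contiguous_done(const unsigned char *done,	/* 1 = DMA done */
				 int done_idx, int upend_idx, int ring_size)
{
	int i, n = 0;

	for (i = done_idx; i != upend_idx; i = (i + 1) % ring_size) {
		if (!done[i])
			break;		/* hole: later completions must wait */
		n++;
	}
	return n;			/* entries that can be signalled now */
}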
361 
362 static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
363 {
364 	struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
365 	struct vhost_virtqueue *vq = ubufs->vq;
366 	int cnt;
367 
368 	rcu_read_lock_bh();
369 
370 	/* set len to mark this descriptor's buffers as having completed DMA */
371 	vq->heads[ubuf->desc].len = success ?
372 		VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
373 	cnt = vhost_net_ubuf_put(ubufs);
374 
375 	/*
376 	 * Trigger polling thread if guest stopped submitting new buffers:
377 	 * in this case, the refcount after decrement will eventually reach 1.
378 	 * We also trigger polling periodically after each 16 packets
379 	 * (the value 16 here is more or less arbitrary, it's tuned to trigger
380 	 * less than 10% of times).
381 	 */
382 	if (cnt <= 1 || !(cnt % 16))
383 		vhost_poll_queue(&vq->poll);
384 
385 	rcu_read_unlock_bh();
386 }
387 
388 static inline unsigned long busy_clock(void)
389 {
390 	return local_clock() >> 10;
391 }
392 
393 static bool vhost_can_busy_poll(struct vhost_dev *dev,
394 				unsigned long endtime)
395 {
396 	return likely(!need_resched()) &&
397 	       likely(!time_after(busy_clock(), endtime)) &&
398 	       likely(!signal_pending(current)) &&
399 	       !vhost_has_work(dev);
400 }
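/*
 * Illustrative sketch (not part of the driver): busy_clock() is roughly
 * local_clock() in microseconds (ns >> 10), and the per-vq busyloop_timeout
 * set via VHOST_SET_VRING_BUSYLOOP_TIMEOUT is in the same units.  A userspace
 * analogue of the deadline-bounded spin, using CLOCK_MONOTONIC (helper names
 * are hypothetical):
 */
#include <time.h>
#include <stdbool.h>

static unsigned long long coarse_clock_units(void)
{
	struct timespec ts;

	clock_gettime(CLOCK_MONOTONIC, &ts);
	/* ~microsecond granularity, like local_clock() >> 10 */
	return ((unsigned long long)ts.tv_sec * 1000000000ULL + ts.tv_nsec) >> 10;
}

/* spin until work() reports progress or the budget is exhausted */
static bool busy_poll(bool (*work)(void *), void *arg, unsigned long budget)
{
	unsigned long long endtime = coarse_clock_units() + budget;

	while (coarse_clock_units() <= endtime)
		if (work(arg))
			return true;
	return false;
}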
401 
402 static void vhost_net_disable_vq(struct vhost_net *n,
403 				 struct vhost_virtqueue *vq)
404 {
405 	struct vhost_net_virtqueue *nvq =
406 		container_of(vq, struct vhost_net_virtqueue, vq);
407 	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
408 	if (!vq->private_data)
409 		return;
410 	vhost_poll_stop(poll);
411 }
412 
413 static int vhost_net_enable_vq(struct vhost_net *n,
414 				struct vhost_virtqueue *vq)
415 {
416 	struct vhost_net_virtqueue *nvq =
417 		container_of(vq, struct vhost_net_virtqueue, vq);
418 	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
419 	struct socket *sock;
420 
421 	sock = vq->private_data;
422 	if (!sock)
423 		return 0;
424 
425 	return vhost_poll_start(poll, sock->file);
426 }
427 
428 static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
429 				    struct vhost_virtqueue *vq,
430 				    struct iovec iov[], unsigned int iov_size,
431 				    unsigned int *out_num, unsigned int *in_num)
432 {
433 	unsigned long uninitialized_var(endtime);
434 	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
435 				  out_num, in_num, NULL, NULL);
436 
437 	if (r == vq->num && vq->busyloop_timeout) {
438 		preempt_disable();
439 		endtime = busy_clock() + vq->busyloop_timeout;
440 		while (vhost_can_busy_poll(vq->dev, endtime) &&
441 		       vhost_vq_avail_empty(vq->dev, vq))
442 			cpu_relax();
443 		preempt_enable();
444 		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
445 				      out_num, in_num, NULL, NULL);
446 	}
447 
448 	return r;
449 }
450 
451 static bool vhost_exceeds_maxpend(struct vhost_net *net)
452 {
453 	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
454 	struct vhost_virtqueue *vq = &nvq->vq;
455 
456 	return (nvq->upend_idx + UIO_MAXIOV - nvq->done_idx) % UIO_MAXIOV >
457 	       min_t(unsigned int, VHOST_MAX_PEND, vq->num >> 2);
458 }
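/*
 * Illustrative sketch (not part of the driver): the expression above counts
 * in-flight zerocopy buffers in the circular [done_idx, upend_idx) window
 * modulo UIO_MAXIOV (1024), handling the wrap-around case.  Two worked
 * examples:
 */
#include <assert.h>

static void pending_count_example(void)
{
	const int ring = 1024;			/* UIO_MAXIOV */

	/* no wrap: upend_idx 100, done_idx 40 -> 60 buffers pending */
	assert((100 + ring - 40) % ring == 60);

	/* wrap: upend_idx 5, done_idx 1000 -> 29 buffers pending */
	assert((5 + ring - 1000) % ring == 29);
}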
459 
460 /* Expects to always be run from the workqueue, which acts as a
461  * read-side critical section for our kind of RCU. */
462 static void handle_tx(struct vhost_net *net)
463 {
464 	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
465 	struct vhost_virtqueue *vq = &nvq->vq;
466 	unsigned out, in;
467 	int head;
468 	struct msghdr msg = {
469 		.msg_name = NULL,
470 		.msg_namelen = 0,
471 		.msg_control = NULL,
472 		.msg_controllen = 0,
473 		.msg_flags = MSG_DONTWAIT,
474 	};
475 	size_t len, total_len = 0;
476 	int err;
477 	size_t hdr_size;
478 	struct socket *sock;
479 	struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
480 	bool zcopy, zcopy_used;
481 	int sent_pkts = 0;
482 
483 	mutex_lock(&vq->mutex);
484 	sock = vq->private_data;
485 	if (!sock)
486 		goto out;
487 
488 	if (!vq_iotlb_prefetch(vq))
489 		goto out;
490 
491 	vhost_disable_notify(&net->dev, vq);
492 	vhost_net_disable_vq(net, vq);
493 
494 	hdr_size = nvq->vhost_hlen;
495 	zcopy = nvq->ubufs;
496 
497 	for (;;) {
498 		/* Release buffers whose DMA has completed first */
499 		if (zcopy)
500 			vhost_zerocopy_signal_used(net, vq);
501 
502 
503 		head = vhost_net_tx_get_vq_desc(net, vq, vq->iov,
504 						ARRAY_SIZE(vq->iov),
505 						&out, &in);
506 		/* On error, stop handling until the next kick. */
507 		if (unlikely(head < 0))
508 			break;
509 		/* Nothing new?  Wait for eventfd to tell us they refilled. */
510 		if (head == vq->num) {
511 			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
512 				vhost_disable_notify(&net->dev, vq);
513 				continue;
514 			}
515 			break;
516 		}
517 		if (in) {
518 			vq_err(vq, "Unexpected descriptor format for TX: "
519 			       "out %d, in %d\n", out, in);
520 			break;
521 		}
522 		/* Skip header. TODO: support TSO. */
523 		len = iov_length(vq->iov, out);
524 		iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len);
525 		iov_iter_advance(&msg.msg_iter, hdr_size);
526 		/* Sanity check */
527 		if (!msg_data_left(&msg)) {
528 			vq_err(vq, "Unexpected header len for TX: "
529 			       "%zd expected %zd\n",
530 			       len, hdr_size);
531 			break;
532 		}
533 		len = msg_data_left(&msg);
534 
535 		zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
536 				   && !vhost_exceeds_maxpend(net)
537 				   && vhost_net_tx_select_zcopy(net);
538 
539 		/* use msg_control to pass vhost zerocopy ubuf info to skb */
540 		if (zcopy_used) {
541 			struct ubuf_info *ubuf;
542 			ubuf = nvq->ubuf_info + nvq->upend_idx;
543 
544 			vq->heads[nvq->upend_idx].id = cpu_to_vhost32(vq, head);
545 			vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
546 			ubuf->callback = vhost_zerocopy_callback;
547 			ubuf->ctx = nvq->ubufs;
548 			ubuf->desc = nvq->upend_idx;
549 			refcount_set(&ubuf->refcnt, 1);
550 			msg.msg_control = ubuf;
551 			msg.msg_controllen = sizeof(ubuf);
552 			ubufs = nvq->ubufs;
553 			atomic_inc(&ubufs->refcount);
554 			nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
555 		} else {
556 			msg.msg_control = NULL;
557 			ubufs = NULL;
558 		}
559 
560 		total_len += len;
561 		if (total_len < VHOST_NET_WEIGHT &&
562 		    !vhost_vq_avail_empty(&net->dev, vq) &&
563 		    likely(!vhost_exceeds_maxpend(net))) {
564 			msg.msg_flags |= MSG_MORE;
565 		} else {
566 			msg.msg_flags &= ~MSG_MORE;
567 		}
568 
569 		/* TODO: Check specific error and bomb out unless ENOBUFS? */
570 		err = sock->ops->sendmsg(sock, &msg, len);
571 		if (unlikely(err < 0)) {
572 			if (zcopy_used) {
573 				vhost_net_ubuf_put(ubufs);
574 				nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
575 					% UIO_MAXIOV;
576 			}
577 			vhost_discard_vq_desc(vq, 1);
578 			vhost_net_enable_vq(net, vq);
579 			break;
580 		}
581 		if (err != len)
582 			pr_debug("Truncated TX packet: "
583 				 "len %d != %zd\n", err, len);
584 		if (!zcopy_used)
585 			vhost_add_used_and_signal(&net->dev, vq, head, 0);
586 		else
587 			vhost_zerocopy_signal_used(net, vq);
588 		vhost_net_tx_packet(net);
589 		if (unlikely(total_len >= VHOST_NET_WEIGHT) ||
590 		    unlikely(++sent_pkts >= VHOST_NET_PKT_WEIGHT(vq))) {
591 			vhost_poll_queue(&vq->poll);
592 			break;
593 		}
594 	}
595 out:
596 	mutex_unlock(&vq->mutex);
597 }
598 
599 static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk)
600 {
601 	struct sk_buff *head;
602 	int len = 0;
603 	unsigned long flags;
604 
605 	if (rvq->rx_ring)
606 		return vhost_net_buf_peek(rvq);
607 
608 	spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
609 	head = skb_peek(&sk->sk_receive_queue);
610 	if (likely(head)) {
611 		len = head->len;
612 		if (skb_vlan_tag_present(head))
613 			len += VLAN_HLEN;
614 	}
615 
616 	spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
617 	return len;
618 }
619 
620 static int sk_has_rx_data(struct sock *sk)
621 {
622 	struct socket *sock = sk->sk_socket;
623 
624 	if (sock->ops->peek_len)
625 		return sock->ops->peek_len(sock);
626 
627 	return skb_queue_empty(&sk->sk_receive_queue);
628 }
629 
630 static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
631 {
632 	struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX];
633 	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
634 	struct vhost_virtqueue *vq = &nvq->vq;
635 	unsigned long uninitialized_var(endtime);
636 	int len = peek_head_len(rvq, sk);
637 
638 	if (!len && vq->busyloop_timeout) {
639 		/* Both tx vq and rx socket were polled here */
640 		mutex_lock_nested(&vq->mutex, 1);
641 		vhost_disable_notify(&net->dev, vq);
642 
643 		preempt_disable();
644 		endtime = busy_clock() + vq->busyloop_timeout;
645 
646 		while (vhost_can_busy_poll(&net->dev, endtime) &&
647 		       !sk_has_rx_data(sk) &&
648 		       vhost_vq_avail_empty(&net->dev, vq))
649 			cpu_relax();
650 
651 		preempt_enable();
652 
653 		if (!vhost_vq_avail_empty(&net->dev, vq))
654 			vhost_poll_queue(&vq->poll);
655 		else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
656 			vhost_disable_notify(&net->dev, vq);
657 			vhost_poll_queue(&vq->poll);
658 		}
659 
660 		mutex_unlock(&vq->mutex);
661 
662 		len = peek_head_len(rvq, sk);
663 	}
664 
665 	return len;
666 }
667 
668 /* This is a multi-buffer version of vhost_get_vq_desc(), that works if
669  *	vq has read descriptors only.
670  * @vq		- the relevant virtqueue
671  * @datalen	- data length we'll be reading
672  * @iovcount	- returned count of io vectors we fill
673  * @log		- vhost log
674  * @log_num	- log offset
675  * @quota       - headcount quota, 1 for big buffer
676  *	returns number of buffer heads allocated, negative on error
677  */
678 static int get_rx_bufs(struct vhost_virtqueue *vq,
679 		       struct vring_used_elem *heads,
680 		       int datalen,
681 		       unsigned *iovcount,
682 		       struct vhost_log *log,
683 		       unsigned *log_num,
684 		       unsigned int quota)
685 {
686 	unsigned int out, in;
687 	int seg = 0;
688 	int headcount = 0;
689 	unsigned d;
690 	int r, nlogs = 0;
691 	/* len is always initialized before use since we are always called with
692 	 * datalen > 0.
693 	 */
694 	u32 uninitialized_var(len);
695 
696 	while (datalen > 0 && headcount < quota) {
697 		if (unlikely(seg >= UIO_MAXIOV)) {
698 			r = -ENOBUFS;
699 			goto err;
700 		}
701 		r = vhost_get_vq_desc(vq, vq->iov + seg,
702 				      ARRAY_SIZE(vq->iov) - seg, &out,
703 				      &in, log, log_num);
704 		if (unlikely(r < 0))
705 			goto err;
706 
707 		d = r;
708 		if (d == vq->num) {
709 			r = 0;
710 			goto err;
711 		}
712 		if (unlikely(out || in <= 0)) {
713 			vq_err(vq, "unexpected descriptor format for RX: "
714 				"out %d, in %d\n", out, in);
715 			r = -EINVAL;
716 			goto err;
717 		}
718 		if (unlikely(log)) {
719 			nlogs += *log_num;
720 			log += *log_num;
721 		}
722 		heads[headcount].id = cpu_to_vhost32(vq, d);
723 		len = iov_length(vq->iov + seg, in);
724 		heads[headcount].len = cpu_to_vhost32(vq, len);
725 		datalen -= len;
726 		++headcount;
727 		seg += in;
728 	}
729 	heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen);
730 	*iovcount = seg;
731 	if (unlikely(log))
732 		*log_num = nlogs;
733 
734 	/* Detect overrun */
735 	if (unlikely(datalen > 0)) {
736 		r = UIO_MAXIOV + 1;
737 		goto err;
738 	}
739 	return headcount;
740 err:
741 	vhost_discard_vq_desc(vq, headcount);
742 	return r;
743 }
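/*
 * Illustrative sketch (not part of the driver): get_rx_bufs() keeps pulling
 * descriptor chains until it has room for datalen bytes (or hits the quota),
 * then trims the recorded length of the last head so the total matches the
 * packet, and flags an overrun if the quota ran out first.  A simplified
 * model over fixed-size buffers (names hypothetical):
 */
static int gather_heads(int datalen, int bufsize, int quota,
			int *lens, int max_heads)
{
	int headcount = 0;

	while (datalen > 0 && headcount < quota && headcount < max_heads) {
		lens[headcount] = bufsize;	/* record this head's capacity */
		datalen -= bufsize;
		headcount++;
	}
	if (headcount)
		lens[headcount - 1] += datalen;	/* datalen <= 0: trim the tail */
	return datalen > 0 ? -1 : headcount;	/* -1 models the overrun case */
}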
744 
745 /* Expects to always be run from the workqueue, which acts as a
746  * read-side critical section for our kind of RCU. */
747 static void handle_rx(struct vhost_net *net)
748 {
749 	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
750 	struct vhost_virtqueue *vq = &nvq->vq;
751 	unsigned uninitialized_var(in), log;
752 	struct vhost_log *vq_log;
753 	struct msghdr msg = {
754 		.msg_name = NULL,
755 		.msg_namelen = 0,
756 		.msg_control = NULL, /* FIXME: get and handle RX aux data. */
757 		.msg_controllen = 0,
758 		.msg_flags = MSG_DONTWAIT,
759 	};
760 	struct virtio_net_hdr hdr = {
761 		.flags = 0,
762 		.gso_type = VIRTIO_NET_HDR_GSO_NONE
763 	};
764 	size_t total_len = 0;
765 	int err, mergeable;
766 	s16 headcount, nheads = 0;
767 	size_t vhost_hlen, sock_hlen;
768 	size_t vhost_len, sock_len;
769 	struct socket *sock;
770 	struct iov_iter fixup;
771 	__virtio16 num_buffers;
772 
773 	mutex_lock_nested(&vq->mutex, 0);
774 	sock = vq->private_data;
775 	if (!sock)
776 		goto out;
777 
778 	if (!vq_iotlb_prefetch(vq))
779 		goto out;
780 
781 	vhost_disable_notify(&net->dev, vq);
782 	vhost_net_disable_vq(net, vq);
783 
784 	vhost_hlen = nvq->vhost_hlen;
785 	sock_hlen = nvq->sock_hlen;
786 
787 	vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
788 		vq->log : NULL;
789 	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
790 
791 	while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
792 		sock_len += sock_hlen;
793 		vhost_len = sock_len + vhost_hlen;
794 		headcount = get_rx_bufs(vq, vq->heads + nheads, vhost_len,
795 					&in, vq_log, &log,
796 					likely(mergeable) ? UIO_MAXIOV : 1);
797 		/* On error, stop handling until the next kick. */
798 		if (unlikely(headcount < 0))
799 			goto out;
800 		/* OK, now we need to know about added descriptors. */
801 		if (!headcount) {
802 			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
803 				/* They have slipped one in as we were
804 				 * doing that: check again. */
805 				vhost_disable_notify(&net->dev, vq);
806 				continue;
807 			}
808 			/* Nothing new?  Wait for eventfd to tell us
809 			 * they refilled. */
810 			goto out;
811 		}
812 		if (nvq->rx_ring)
813 			msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
814 		/* On overrun, truncate and discard */
815 		if (unlikely(headcount > UIO_MAXIOV)) {
816 			iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
817 			err = sock->ops->recvmsg(sock, &msg,
818 						 1, MSG_DONTWAIT | MSG_TRUNC);
819 			pr_debug("Discarded rx packet: len %zd\n", sock_len);
820 			continue;
821 		}
822 		/* We don't need to be notified again. */
823 		iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
824 		fixup = msg.msg_iter;
825 		if (unlikely((vhost_hlen))) {
826 			/* We will supply the header ourselves
827 			 * TODO: support TSO.
828 			 */
829 			iov_iter_advance(&msg.msg_iter, vhost_hlen);
830 		}
831 		err = sock->ops->recvmsg(sock, &msg,
832 					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
833 		/* Userspace might have consumed the packet meanwhile:
834 		 * it's not supposed to do this usually, but might be hard
835 		 * to prevent. Discard data we got (if any) and keep going. */
836 		if (unlikely(err != sock_len)) {
837 			pr_debug("Discarded rx packet: "
838 				 "len %d, expected %zd\n", err, sock_len);
839 			vhost_discard_vq_desc(vq, headcount);
840 			continue;
841 		}
842 		/* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
843 		if (unlikely(vhost_hlen)) {
844 			if (copy_to_iter(&hdr, sizeof(hdr),
845 					 &fixup) != sizeof(hdr)) {
846 				vq_err(vq, "Unable to write vnet_hdr "
847 				       "at addr %p\n", vq->iov->iov_base);
848 				goto out;
849 			}
850 		} else {
851 			/* Header came from socket; we'll need to patch
852 			 * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
853 			 */
854 			iov_iter_advance(&fixup, sizeof(hdr));
855 		}
856 		/* TODO: Should check and handle checksum. */
857 
858 		num_buffers = cpu_to_vhost16(vq, headcount);
859 		if (likely(mergeable) &&
860 		    copy_to_iter(&num_buffers, sizeof num_buffers,
861 				 &fixup) != sizeof num_buffers) {
862 			vq_err(vq, "Failed num_buffers write");
863 			vhost_discard_vq_desc(vq, headcount);
864 			goto out;
865 		}
866 		nheads += headcount;
867 		if (nheads > VHOST_RX_BATCH) {
868 			vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
869 						    nheads);
870 			nheads = 0;
871 		}
872 		if (unlikely(vq_log))
873 			vhost_log_write(vq, vq_log, log, vhost_len);
874 		total_len += vhost_len;
875 		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
876 			vhost_poll_queue(&vq->poll);
877 			goto out;
878 		}
879 	}
880 	vhost_net_enable_vq(net, vq);
881 out:
882 	if (nheads)
883 		vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
884 					    nheads);
885 	mutex_unlock(&vq->mutex);
886 }
887 
888 static void handle_tx_kick(struct vhost_work *work)
889 {
890 	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
891 						  poll.work);
892 	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
893 
894 	handle_tx(net);
895 }
896 
897 static void handle_rx_kick(struct vhost_work *work)
898 {
899 	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
900 						  poll.work);
901 	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
902 
903 	handle_rx(net);
904 }
905 
906 static void handle_tx_net(struct vhost_work *work)
907 {
908 	struct vhost_net *net = container_of(work, struct vhost_net,
909 					     poll[VHOST_NET_VQ_TX].work);
910 	handle_tx(net);
911 }
912 
913 static void handle_rx_net(struct vhost_work *work)
914 {
915 	struct vhost_net *net = container_of(work, struct vhost_net,
916 					     poll[VHOST_NET_VQ_RX].work);
917 	handle_rx(net);
918 }
919 
920 static int vhost_net_open(struct inode *inode, struct file *f)
921 {
922 	struct vhost_net *n;
923 	struct vhost_dev *dev;
924 	struct vhost_virtqueue **vqs;
925 	void **queue;
926 	int i;
927 
928 	n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL);
929 	if (!n)
930 		return -ENOMEM;
931 	vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
932 	if (!vqs) {
933 		kvfree(n);
934 		return -ENOMEM;
935 	}
936 
937 	queue = kmalloc_array(VHOST_RX_BATCH, sizeof(void *),
938 			      GFP_KERNEL);
939 	if (!queue) {
940 		kfree(vqs);
941 		kvfree(n);
942 		return -ENOMEM;
943 	}
944 	n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;
945 
946 	dev = &n->dev;
947 	vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
948 	vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
949 	n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
950 	n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
951 	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
952 		n->vqs[i].ubufs = NULL;
953 		n->vqs[i].ubuf_info = NULL;
954 		n->vqs[i].upend_idx = 0;
955 		n->vqs[i].done_idx = 0;
956 		n->vqs[i].vhost_hlen = 0;
957 		n->vqs[i].sock_hlen = 0;
958 		n->vqs[i].rx_ring = NULL;
959 		vhost_net_buf_init(&n->vqs[i].rxq);
960 	}
961 	vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
962 
963 	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
964 	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
965 
966 	f->private_data = n;
967 
968 	return 0;
969 }
970 
971 static struct socket *vhost_net_stop_vq(struct vhost_net *n,
972 					struct vhost_virtqueue *vq)
973 {
974 	struct socket *sock;
975 	struct vhost_net_virtqueue *nvq =
976 		container_of(vq, struct vhost_net_virtqueue, vq);
977 
978 	mutex_lock(&vq->mutex);
979 	sock = vq->private_data;
980 	vhost_net_disable_vq(n, vq);
981 	vq->private_data = NULL;
982 	vhost_net_buf_unproduce(nvq);
983 	nvq->rx_ring = NULL;
984 	mutex_unlock(&vq->mutex);
985 	return sock;
986 }
987 
988 static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
989 			   struct socket **rx_sock)
990 {
991 	*tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
992 	*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
993 }
994 
995 static void vhost_net_flush_vq(struct vhost_net *n, int index)
996 {
997 	vhost_poll_flush(n->poll + index);
998 	vhost_poll_flush(&n->vqs[index].vq.poll);
999 }
1000 
1001 static void vhost_net_flush(struct vhost_net *n)
1002 {
1003 	vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
1004 	vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
1005 	if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
1006 		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1007 		n->tx_flush = true;
1008 		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1009 		/* Wait for all lower device DMAs to complete. */
1010 		vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
1011 		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1012 		n->tx_flush = false;
1013 		atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
1014 		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1015 	}
1016 }
1017 
1018 static int vhost_net_release(struct inode *inode, struct file *f)
1019 {
1020 	struct vhost_net *n = f->private_data;
1021 	struct socket *tx_sock;
1022 	struct socket *rx_sock;
1023 
1024 	vhost_net_stop(n, &tx_sock, &rx_sock);
1025 	vhost_net_flush(n);
1026 	vhost_dev_stop(&n->dev);
1027 	vhost_dev_cleanup(&n->dev);
1028 	vhost_net_vq_reset(n);
1029 	if (tx_sock)
1030 		sockfd_put(tx_sock);
1031 	if (rx_sock)
1032 		sockfd_put(rx_sock);
1033 	/* Make sure no callbacks are outstanding */
1034 	synchronize_rcu_bh();
1035 	/* We do an extra flush before freeing memory,
1036 	 * since jobs can re-queue themselves. */
1037 	vhost_net_flush(n);
1038 	kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
1039 	kfree(n->dev.vqs);
1040 	kvfree(n);
1041 	return 0;
1042 }
1043 
1044 static struct socket *get_raw_socket(int fd)
1045 {
1046 	struct {
1047 		struct sockaddr_ll sa;
1048 		char  buf[MAX_ADDR_LEN];
1049 	} uaddr;
1050 	int r;
1051 	struct socket *sock = sockfd_lookup(fd, &r);
1052 
1053 	if (!sock)
1054 		return ERR_PTR(-ENOTSOCK);
1055 
1056 	/* Parameter checking */
1057 	if (sock->sk->sk_type != SOCK_RAW) {
1058 		r = -ESOCKTNOSUPPORT;
1059 		goto err;
1060 	}
1061 
1062 	r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa, 0);
1063 	if (r < 0)
1064 		goto err;
1065 
1066 	if (uaddr.sa.sll_family != AF_PACKET) {
1067 		r = -EPFNOSUPPORT;
1068 		goto err;
1069 	}
1070 	return sock;
1071 err:
1072 	sockfd_put(sock);
1073 	return ERR_PTR(r);
1074 }
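/*
 * Illustrative sketch (not part of the driver): the fd accepted by
 * get_raw_socket() is an AF_PACKET socket as userspace would create it before
 * handing it to VHOST_NET_SET_BACKEND.  Requires CAP_NET_RAW; the interface
 * name is a placeholder and error handling is minimal.
 */
#include <sys/socket.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>	/* ETH_P_ALL */
#include <net/if.h>		/* if_nametoindex() */
#include <arpa/inet.h>		/* htons() */
#include <string.h>
#include <unistd.h>

static int open_packet_socket(const char *ifname)
{
	struct sockaddr_ll sll;
	int fd = socket(AF_PACKET, SOCK_RAW, htons(ETH_P_ALL));

	if (fd < 0)
		return -1;

	memset(&sll, 0, sizeof(sll));
	sll.sll_family = AF_PACKET;		/* what get_raw_socket() checks */
	sll.sll_protocol = htons(ETH_P_ALL);
	sll.sll_ifindex = if_nametoindex(ifname);
	if (bind(fd, (struct sockaddr *)&sll, sizeof(sll)) < 0) {
		close(fd);
		return -1;
	}
	return fd;
}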
1075 
1076 static struct ptr_ring *get_tap_ptr_ring(int fd)
1077 {
1078 	struct ptr_ring *ring;
1079 	struct file *file = fget(fd);
1080 
1081 	if (!file)
1082 		return NULL;
1083 	ring = tun_get_tx_ring(file);
1084 	if (!IS_ERR(ring))
1085 		goto out;
1086 	ring = tap_get_ptr_ring(file);
1087 	if (!IS_ERR(ring))
1088 		goto out;
1089 	ring = NULL;
1090 out:
1091 	fput(file);
1092 	return ring;
1093 }
1094 
1095 static struct socket *get_tap_socket(int fd)
1096 {
1097 	struct file *file = fget(fd);
1098 	struct socket *sock;
1099 
1100 	if (!file)
1101 		return ERR_PTR(-EBADF);
1102 	sock = tun_get_socket(file);
1103 	if (!IS_ERR(sock))
1104 		return sock;
1105 	sock = tap_get_socket(file);
1106 	if (IS_ERR(sock))
1107 		fput(file);
1108 	return sock;
1109 }
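/*
 * Illustrative sketch (not part of the driver): the more common backend is a
 * tun/tap fd.  This is roughly how a VMM would open one with a virtio-net
 * header enabled before passing the fd to VHOST_NET_SET_BACKEND.  The device
 * name is a placeholder and error handling is minimal.
 */
#include <fcntl.h>
#include <string.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/if.h>
#include <linux/if_tun.h>

static int open_tap(const char *name)
{
	struct ifreq ifr;
	int fd = open("/dev/net/tun", O_RDWR);

	if (fd < 0)
		return -1;

	memset(&ifr, 0, sizeof(ifr));
	ifr.ifr_flags = IFF_TAP | IFF_NO_PI | IFF_VNET_HDR;
	strncpy(ifr.ifr_name, name, IFNAMSIZ - 1);
	if (ioctl(fd, TUNSETIFF, &ifr) < 0) {
		close(fd);
		return -1;
	}
	return fd;	/* tun_get_socket() in this file accepts such an fd */
}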
1110 
1111 static struct socket *get_socket(int fd)
1112 {
1113 	struct socket *sock;
1114 
1115 	/* special case to disable backend */
1116 	if (fd == -1)
1117 		return NULL;
1118 	sock = get_raw_socket(fd);
1119 	if (!IS_ERR(sock))
1120 		return sock;
1121 	sock = get_tap_socket(fd);
1122 	if (!IS_ERR(sock))
1123 		return sock;
1124 	return ERR_PTR(-ENOTSOCK);
1125 }
1126 
1127 static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
1128 {
1129 	struct socket *sock, *oldsock;
1130 	struct vhost_virtqueue *vq;
1131 	struct vhost_net_virtqueue *nvq;
1132 	struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL;
1133 	int r;
1134 
1135 	mutex_lock(&n->dev.mutex);
1136 	r = vhost_dev_check_owner(&n->dev);
1137 	if (r)
1138 		goto err;
1139 
1140 	if (index >= VHOST_NET_VQ_MAX) {
1141 		r = -ENOBUFS;
1142 		goto err;
1143 	}
1144 	vq = &n->vqs[index].vq;
1145 	nvq = &n->vqs[index];
1146 	mutex_lock(&vq->mutex);
1147 
1148 	/* Verify that ring has been set up correctly. */
1149 	if (!vhost_vq_access_ok(vq)) {
1150 		r = -EFAULT;
1151 		goto err_vq;
1152 	}
1153 	sock = get_socket(fd);
1154 	if (IS_ERR(sock)) {
1155 		r = PTR_ERR(sock);
1156 		goto err_vq;
1157 	}
1158 
1159 	/* start polling new socket */
1160 	oldsock = vq->private_data;
1161 	if (sock != oldsock) {
1162 		ubufs = vhost_net_ubuf_alloc(vq,
1163 					     sock && vhost_sock_zcopy(sock));
1164 		if (IS_ERR(ubufs)) {
1165 			r = PTR_ERR(ubufs);
1166 			goto err_ubufs;
1167 		}
1168 
1169 		vhost_net_disable_vq(n, vq);
1170 		vq->private_data = sock;
1171 		vhost_net_buf_unproduce(nvq);
1172 		r = vhost_vq_init_access(vq);
1173 		if (r)
1174 			goto err_used;
1175 		r = vhost_net_enable_vq(n, vq);
1176 		if (r)
1177 			goto err_used;
1178 		if (index == VHOST_NET_VQ_RX)
1179 			nvq->rx_ring = get_tap_ptr_ring(fd);
1180 
1181 		oldubufs = nvq->ubufs;
1182 		nvq->ubufs = ubufs;
1183 
1184 		n->tx_packets = 0;
1185 		n->tx_zcopy_err = 0;
1186 		n->tx_flush = false;
1187 	}
1188 
1189 	mutex_unlock(&vq->mutex);
1190 
1191 	if (oldubufs) {
1192 		vhost_net_ubuf_put_wait_and_free(oldubufs);
1193 		mutex_lock(&vq->mutex);
1194 		vhost_zerocopy_signal_used(n, vq);
1195 		mutex_unlock(&vq->mutex);
1196 	}
1197 
1198 	if (oldsock) {
1199 		vhost_net_flush_vq(n, index);
1200 		sockfd_put(oldsock);
1201 	}
1202 
1203 	mutex_unlock(&n->dev.mutex);
1204 	return 0;
1205 
1206 err_used:
1207 	vq->private_data = oldsock;
1208 	vhost_net_enable_vq(n, vq);
1209 	if (ubufs)
1210 		vhost_net_ubuf_put_wait_and_free(ubufs);
1211 err_ubufs:
1212 	sockfd_put(sock);
1213 err_vq:
1214 	mutex_unlock(&vq->mutex);
1215 err:
1216 	mutex_unlock(&n->dev.mutex);
1217 	return r;
1218 }
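/*
 * Illustrative sketch (not part of the driver): the ioctl sequence a VMM
 * would use to attach a tap fd as the RX and TX backend.  The memory table
 * and vrings must already be configured (VHOST_SET_MEM_TABLE,
 * VHOST_SET_VRING_*), otherwise this path fails with -EFAULT via
 * vhost_vq_access_ok() above.  tap_fd could come from something like the
 * open_tap() sketch earlier.
 */
#include <sys/ioctl.h>
#include <linux/vhost.h>

static int attach_backend(int vhost_fd, int tap_fd)
{
	struct vhost_vring_file backend;

	/* index 0 = RX queue, 1 = TX queue (VHOST_NET_VQ_RX/TX above) */
	backend.index = 0;
	backend.fd = tap_fd;
	if (ioctl(vhost_fd, VHOST_NET_SET_BACKEND, &backend) < 0)
		return -1;

	backend.index = 1;
	if (ioctl(vhost_fd, VHOST_NET_SET_BACKEND, &backend) < 0)
		return -1;
	return 0;
}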
1219 
1220 static long vhost_net_reset_owner(struct vhost_net *n)
1221 {
1222 	struct socket *tx_sock = NULL;
1223 	struct socket *rx_sock = NULL;
1224 	long err;
1225 	struct vhost_umem *umem;
1226 
1227 	mutex_lock(&n->dev.mutex);
1228 	err = vhost_dev_check_owner(&n->dev);
1229 	if (err)
1230 		goto done;
1231 	umem = vhost_dev_reset_owner_prepare();
1232 	if (!umem) {
1233 		err = -ENOMEM;
1234 		goto done;
1235 	}
1236 	vhost_net_stop(n, &tx_sock, &rx_sock);
1237 	vhost_net_flush(n);
1238 	vhost_dev_stop(&n->dev);
1239 	vhost_dev_reset_owner(&n->dev, umem);
1240 	vhost_net_vq_reset(n);
1241 done:
1242 	mutex_unlock(&n->dev.mutex);
1243 	if (tx_sock)
1244 		sockfd_put(tx_sock);
1245 	if (rx_sock)
1246 		sockfd_put(rx_sock);
1247 	return err;
1248 }
1249 
1250 static int vhost_net_set_features(struct vhost_net *n, u64 features)
1251 {
1252 	size_t vhost_hlen, sock_hlen, hdr_len;
1253 	int i;
1254 
1255 	hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
1256 			       (1ULL << VIRTIO_F_VERSION_1))) ?
1257 			sizeof(struct virtio_net_hdr_mrg_rxbuf) :
1258 			sizeof(struct virtio_net_hdr);
1259 	if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
1260 		/* vhost provides vnet_hdr */
1261 		vhost_hlen = hdr_len;
1262 		sock_hlen = 0;
1263 	} else {
1264 		/* socket provides vnet_hdr */
1265 		vhost_hlen = 0;
1266 		sock_hlen = hdr_len;
1267 	}
1268 	mutex_lock(&n->dev.mutex);
1269 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
1270 	    !vhost_log_access_ok(&n->dev))
1271 		goto out_unlock;
1272 
1273 	if ((features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))) {
1274 		if (vhost_init_device_iotlb(&n->dev, true))
1275 			goto out_unlock;
1276 	}
1277 
1278 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
1279 		mutex_lock(&n->vqs[i].vq.mutex);
1280 		n->vqs[i].vq.acked_features = features;
1281 		n->vqs[i].vhost_hlen = vhost_hlen;
1282 		n->vqs[i].sock_hlen = sock_hlen;
1283 		mutex_unlock(&n->vqs[i].vq.mutex);
1284 	}
1285 	mutex_unlock(&n->dev.mutex);
1286 	return 0;
1287 
1288 out_unlock:
1289 	mutex_unlock(&n->dev.mutex);
1290 	return -EFAULT;
1291 }
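/*
 * Illustrative sketch (not part of the driver): the header length chosen
 * above depends only on whether the mergeable-buffer or virtio 1.0 layout is
 * negotiated.  struct virtio_net_hdr is 10 bytes; the _mrg_rxbuf variant adds
 * a 16-bit num_buffers field, giving 12 bytes:
 */
#include <linux/virtio_net.h>
#include <linux/virtio_config.h>

static unsigned long negotiated_hdr_len(unsigned long long features)
{
	if (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
			(1ULL << VIRTIO_F_VERSION_1)))
		return sizeof(struct virtio_net_hdr_mrg_rxbuf);	/* 12 bytes */
	return sizeof(struct virtio_net_hdr);			/* 10 bytes */
}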
1292 
1293 static long vhost_net_set_owner(struct vhost_net *n)
1294 {
1295 	int r;
1296 
1297 	mutex_lock(&n->dev.mutex);
1298 	if (vhost_dev_has_owner(&n->dev)) {
1299 		r = -EBUSY;
1300 		goto out;
1301 	}
1302 	r = vhost_net_set_ubuf_info(n);
1303 	if (r)
1304 		goto out;
1305 	r = vhost_dev_set_owner(&n->dev);
1306 	if (r)
1307 		vhost_net_clear_ubuf_info(n);
1308 	vhost_net_flush(n);
1309 out:
1310 	mutex_unlock(&n->dev.mutex);
1311 	return r;
1312 }
1313 
1314 static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
1315 			    unsigned long arg)
1316 {
1317 	struct vhost_net *n = f->private_data;
1318 	void __user *argp = (void __user *)arg;
1319 	u64 __user *featurep = argp;
1320 	struct vhost_vring_file backend;
1321 	u64 features;
1322 	int r;
1323 
1324 	switch (ioctl) {
1325 	case VHOST_NET_SET_BACKEND:
1326 		if (copy_from_user(&backend, argp, sizeof backend))
1327 			return -EFAULT;
1328 		return vhost_net_set_backend(n, backend.index, backend.fd);
1329 	case VHOST_GET_FEATURES:
1330 		features = VHOST_NET_FEATURES;
1331 		if (copy_to_user(featurep, &features, sizeof features))
1332 			return -EFAULT;
1333 		return 0;
1334 	case VHOST_SET_FEATURES:
1335 		if (copy_from_user(&features, featurep, sizeof features))
1336 			return -EFAULT;
1337 		if (features & ~VHOST_NET_FEATURES)
1338 			return -EOPNOTSUPP;
1339 		return vhost_net_set_features(n, features);
1340 	case VHOST_RESET_OWNER:
1341 		return vhost_net_reset_owner(n);
1342 	case VHOST_SET_OWNER:
1343 		return vhost_net_set_owner(n);
1344 	default:
1345 		mutex_lock(&n->dev.mutex);
1346 		r = vhost_dev_ioctl(&n->dev, ioctl, argp);
1347 		if (r == -ENOIOCTLCMD)
1348 			r = vhost_vring_ioctl(&n->dev, ioctl, argp);
1349 		else
1350 			vhost_net_flush(n);
1351 		mutex_unlock(&n->dev.mutex);
1352 		return r;
1353 	}
1354 }
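/*
 * Illustrative sketch (not part of the driver): a minimal userspace feature
 * negotiation against this ioctl interface.  The device path is the standard
 * misc device; which feature bits to keep is up to the VMM, this only shows
 * the GET/mask/SET shape.
 */
#include <fcntl.h>
#include <stdint.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>

static int negotiate_features(void)
{
	uint64_t features;
	int fd = open("/dev/vhost-net", O_RDWR);

	if (fd < 0)
		return -1;
	if (ioctl(fd, VHOST_SET_OWNER) < 0 ||
	    ioctl(fd, VHOST_GET_FEATURES, &features) < 0) {
		close(fd);
		return -1;
	}

	/* keep only what this VMM understands, e.g. mergeable RX buffers */
	features &= (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
		    (1ULL << VHOST_NET_F_VIRTIO_NET_HDR);
	if (ioctl(fd, VHOST_SET_FEATURES, &features) < 0) {
		close(fd);
		return -1;
	}
	return fd;	/* caller keeps the fd for further setup */
}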
1355 
1356 #ifdef CONFIG_COMPAT
1357 static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
1358 				   unsigned long arg)
1359 {
1360 	return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
1361 }
1362 #endif
1363 
1364 static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
1365 {
1366 	struct file *file = iocb->ki_filp;
1367 	struct vhost_net *n = file->private_data;
1368 	struct vhost_dev *dev = &n->dev;
1369 	int noblock = file->f_flags & O_NONBLOCK;
1370 
1371 	return vhost_chr_read_iter(dev, to, noblock);
1372 }
1373 
1374 static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
1375 					struct iov_iter *from)
1376 {
1377 	struct file *file = iocb->ki_filp;
1378 	struct vhost_net *n = file->private_data;
1379 	struct vhost_dev *dev = &n->dev;
1380 
1381 	return vhost_chr_write_iter(dev, from);
1382 }
1383 
1384 static __poll_t vhost_net_chr_poll(struct file *file, poll_table *wait)
1385 {
1386 	struct vhost_net *n = file->private_data;
1387 	struct vhost_dev *dev = &n->dev;
1388 
1389 	return vhost_chr_poll(file, dev, wait);
1390 }
1391 
1392 static const struct file_operations vhost_net_fops = {
1393 	.owner          = THIS_MODULE,
1394 	.release        = vhost_net_release,
1395 	.read_iter      = vhost_net_chr_read_iter,
1396 	.write_iter     = vhost_net_chr_write_iter,
1397 	.poll           = vhost_net_chr_poll,
1398 	.unlocked_ioctl = vhost_net_ioctl,
1399 #ifdef CONFIG_COMPAT
1400 	.compat_ioctl   = vhost_net_compat_ioctl,
1401 #endif
1402 	.open           = vhost_net_open,
1403 	.llseek		= noop_llseek,
1404 };
1405 
1406 static struct miscdevice vhost_net_misc = {
1407 	.minor = VHOST_NET_MINOR,
1408 	.name = "vhost-net",
1409 	.fops = &vhost_net_fops,
1410 };
1411 
1412 static int vhost_net_init(void)
1413 {
1414 	if (experimental_zcopytx)
1415 		vhost_net_enable_zcopy(VHOST_NET_VQ_TX);
1416 	return misc_register(&vhost_net_misc);
1417 }
1418 module_init(vhost_net_init);
1419 
1420 static void vhost_net_exit(void)
1421 {
1422 	misc_deregister(&vhost_net_misc);
1423 }
1424 module_exit(vhost_net_exit);
1425 
1426 MODULE_VERSION("0.0.1");
1427 MODULE_LICENSE("GPL v2");
1428 MODULE_AUTHOR("Michael S. Tsirkin");
1429 MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
1430 MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
1431 MODULE_ALIAS("devname:vhost-net");
1432