xref: /openbmc/linux/drivers/xen/pvcalls-front.c (revision 6d99a79c)
1 /*
2  * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; either version 2 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  */
14 
15 #include <linux/module.h>
16 #include <linux/net.h>
17 #include <linux/socket.h>
18 
19 #include <net/sock.h>
20 
21 #include <xen/events.h>
22 #include <xen/grant_table.h>
23 #include <xen/xen.h>
24 #include <xen/xenbus.h>
25 #include <xen/interface/io/pvcalls.h>
26 
27 #include "pvcalls-front.h"
28 
/*
 * Sentinel request id: marks a bedata->rsp[] slot as free and
 * passive.inflight_req_id as "no accept pending".
 */
#define PVCALLS_INVALID_ID UINT_MAX
/* Page order of each active socket's data area (see create_active()). */
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
/*
 * Number of slots in a one-page xen_pvcalls ring; sizes bedata->rsp[] so
 * that every request slot has a matching response slot.
 */
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
/* Maximum number of __write_ring() attempts per sendmsg call. */
#define PVCALLS_FRONT_MAX_SPIN 5000
33 
/* Frontend state for the (single) connection to the PV Calls backend. */
struct pvcalls_bedata {
	/* command ring shared with the backend */
	struct xen_pvcalls_front_ring ring;
	/*
	 * Grant reference released in pvcalls_front_remove(); presumably the
	 * grant of the command ring page set up in probe - TODO confirm.
	 */
	grant_ref_t ref;
	/* irq used to notify the backend after pushing commands */
	int irq;

	/* every struct sock_mapping of this frontend */
	struct list_head socket_mappings;
	/* serialises command-ring producers and socket_mappings updates */
	spinlock_t socket_lock;

	/* woken when responses arrive (pvcalls_front_event_handler()) */
	wait_queue_head_t inflight_req;
	/*
	 * Local copies of backend responses, indexed by request id.  A slot
	 * is free when its req_id is PVCALLS_INVALID_ID.
	 */
	struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
/*
 * Number of frontend operations currently running (see pvcalls_enter() /
 * pvcalls_exit()); pvcalls_front_remove() waits for it to drop to zero
 * before freeing shared state.
 */
static atomic_t pvcalls_refcount;
48 
/*
 * pvcalls_enter()/pvcalls_exit() bracket every frontend operation so that
 * pvcalls_front_remove() can wait for running operations to drain before
 * tearing down shared state.  Both expand to a single statement
 * (do { } while (0)) so they are safe in unbraced if/else bodies.
 */

/* first increment refcount, then proceed */
#define pvcalls_enter() do {            \
	atomic_inc(&pvcalls_refcount);      \
} while (0)

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() do {             \
	atomic_dec(&pvcalls_refcount);      \
} while (0)
58 
/*
 * Per-socket frontend state.  A pointer to it is stored in
 * sock->sk->sk_send_head and is also sent to the backend as the socket id.
 */
struct sock_mapping {
	/* true once create_active() succeeded; selects the union member */
	bool active_socket;
	/* link in bedata->socket_mappings */
	struct list_head list;
	/* owning socket */
	struct socket *sock;
	/* operations currently running on this socket (pvcalls_enter_sock) */
	atomic_t refcount;
	union {
		/* connected socket: shared data rings and event channel */
		struct {
			/* irq bound to the socket's event channel */
			int irq;
			/* grant of the pvcalls_data_intf page */
			grant_ref_t ref;
			struct pvcalls_data_intf *ring;
			/* in/out data areas */
			struct pvcalls_data data;
			/* serialises recvmsg callers */
			struct mutex in_mutex;
			/* serialises sendmsg callers */
			struct mutex out_mutex;

			/* woken by pvcalls_front_conn_handler() */
			wait_queue_head_t inflight_conn_req;
		} active;
		/* bound or listening socket */
		struct {
		/*
		 * Socket status, needs to be 64-bit aligned due to the
		 * test_and_* functions which have this requirement on arm64.
		 */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
			uint8_t status __attribute__((aligned(8)));
		/*
		 * Internal state-machine flags.
		 * Only one accept operation can be inflight for a socket.
		 * Only one poll operation can be inflight for a given socket.
		 * POLL_RET is set by the event handler when a PVCALLS_POLL
		 * response arrives and consumed by the next poll.
		 * flags needs to be 64-bit aligned due to the test_and_*
		 * functions which have this requirement on arm64.
		 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
			uint8_t flags __attribute__((aligned(8)));
			/*
			 * Request id of an accept whose response has not been
			 * collected yet, or PVCALLS_INVALID_ID.
			 */
			uint32_t inflight_req_id;
			/* mapping prepared for the connection being accepted */
			struct sock_mapping *accept_map;
			/* accept callers waiting for ACCEPT_INFLIGHT to clear */
			wait_queue_head_t inflight_accept_req;
		} passive;
	};
};
101 
102 static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
103 {
104 	struct sock_mapping *map;
105 
106 	if (!pvcalls_front_dev ||
107 		dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
108 		return ERR_PTR(-ENOTCONN);
109 
110 	map = (struct sock_mapping *)sock->sk->sk_send_head;
111 	if (map == NULL)
112 		return ERR_PTR(-ENOTSOCK);
113 
114 	pvcalls_enter();
115 	atomic_inc(&map->refcount);
116 	return map;
117 }
118 
119 static inline void pvcalls_exit_sock(struct socket *sock)
120 {
121 	struct sock_mapping *map;
122 
123 	map = (struct sock_mapping *)sock->sk->sk_send_head;
124 	atomic_dec(&map->refcount);
125 	pvcalls_exit();
126 }
127 
128 static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
129 {
130 	*req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
131 	if (RING_FULL(&bedata->ring) ||
132 	    bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
133 		return -EAGAIN;
134 	return 0;
135 }
136 
137 static bool pvcalls_front_write_todo(struct sock_mapping *map)
138 {
139 	struct pvcalls_data_intf *intf = map->active.ring;
140 	RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
141 	int32_t error;
142 
143 	error = intf->out_error;
144 	if (error == -ENOTCONN)
145 		return false;
146 	if (error != 0)
147 		return true;
148 
149 	cons = intf->out_cons;
150 	prod = intf->out_prod;
151 	return !!(size - pvcalls_queued(prod, cons, size));
152 }
153 
154 static bool pvcalls_front_read_todo(struct sock_mapping *map)
155 {
156 	struct pvcalls_data_intf *intf = map->active.ring;
157 	RING_IDX cons, prod;
158 	int32_t error;
159 
160 	cons = intf->in_cons;
161 	prod = intf->in_prod;
162 	error = intf->in_error;
163 	return (error != 0 ||
164 		pvcalls_queued(prod, cons,
165 			       XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
166 }
167 
168 static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
169 {
170 	struct xenbus_device *dev = dev_id;
171 	struct pvcalls_bedata *bedata;
172 	struct xen_pvcalls_response *rsp;
173 	uint8_t *src, *dst;
174 	int req_id = 0, more = 0, done = 0;
175 
176 	if (dev == NULL)
177 		return IRQ_HANDLED;
178 
179 	pvcalls_enter();
180 	bedata = dev_get_drvdata(&dev->dev);
181 	if (bedata == NULL) {
182 		pvcalls_exit();
183 		return IRQ_HANDLED;
184 	}
185 
186 again:
187 	while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
188 		rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);
189 
190 		req_id = rsp->req_id;
191 		if (rsp->cmd == PVCALLS_POLL) {
192 			struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
193 						   rsp->u.poll.id;
194 
195 			clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
196 				  (void *)&map->passive.flags);
197 			/*
198 			 * clear INFLIGHT, then set RET. It pairs with
199 			 * the checks at the beginning of
200 			 * pvcalls_front_poll_passive.
201 			 */
202 			smp_wmb();
203 			set_bit(PVCALLS_FLAG_POLL_RET,
204 				(void *)&map->passive.flags);
205 		} else {
206 			dst = (uint8_t *)&bedata->rsp[req_id] +
207 			      sizeof(rsp->req_id);
208 			src = (uint8_t *)rsp + sizeof(rsp->req_id);
209 			memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
210 			/*
211 			 * First copy the rest of the data, then req_id. It is
212 			 * paired with the barrier when accessing bedata->rsp.
213 			 */
214 			smp_wmb();
215 			bedata->rsp[req_id].req_id = req_id;
216 		}
217 
218 		done = 1;
219 		bedata->ring.rsp_cons++;
220 	}
221 
222 	RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
223 	if (more)
224 		goto again;
225 	if (done)
226 		wake_up(&bedata->inflight_req);
227 	pvcalls_exit();
228 	return IRQ_HANDLED;
229 }
230 
231 static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
232 				   struct sock_mapping *map)
233 {
234 	int i;
235 
236 	unbind_from_irqhandler(map->active.irq, map);
237 
238 	spin_lock(&bedata->socket_lock);
239 	if (!list_empty(&map->list))
240 		list_del_init(&map->list);
241 	spin_unlock(&bedata->socket_lock);
242 
243 	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
244 		gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
245 	gnttab_end_foreign_access(map->active.ref, 0, 0);
246 	free_page((unsigned long)map->active.ring);
247 
248 	kfree(map);
249 }
250 
251 static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
252 {
253 	struct sock_mapping *map = sock_map;
254 
255 	if (map == NULL)
256 		return IRQ_HANDLED;
257 
258 	wake_up_interruptible(&map->active.inflight_conn_req);
259 
260 	return IRQ_HANDLED;
261 }
262 
263 int pvcalls_front_socket(struct socket *sock)
264 {
265 	struct pvcalls_bedata *bedata;
266 	struct sock_mapping *map = NULL;
267 	struct xen_pvcalls_request *req;
268 	int notify, req_id, ret;
269 
270 	/*
271 	 * PVCalls only supports domain AF_INET,
272 	 * type SOCK_STREAM and protocol 0 sockets for now.
273 	 *
274 	 * Check socket type here, AF_INET and protocol checks are done
275 	 * by the caller.
276 	 */
277 	if (sock->type != SOCK_STREAM)
278 		return -EOPNOTSUPP;
279 
280 	pvcalls_enter();
281 	if (!pvcalls_front_dev) {
282 		pvcalls_exit();
283 		return -EACCES;
284 	}
285 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
286 
287 	map = kzalloc(sizeof(*map), GFP_KERNEL);
288 	if (map == NULL) {
289 		pvcalls_exit();
290 		return -ENOMEM;
291 	}
292 
293 	spin_lock(&bedata->socket_lock);
294 
295 	ret = get_request(bedata, &req_id);
296 	if (ret < 0) {
297 		kfree(map);
298 		spin_unlock(&bedata->socket_lock);
299 		pvcalls_exit();
300 		return ret;
301 	}
302 
303 	/*
304 	 * sock->sk->sk_send_head is not used for ip sockets: reuse the
305 	 * field to store a pointer to the struct sock_mapping
306 	 * corresponding to the socket. This way, we can easily get the
307 	 * struct sock_mapping from the struct socket.
308 	 */
309 	sock->sk->sk_send_head = (void *)map;
310 	list_add_tail(&map->list, &bedata->socket_mappings);
311 
312 	req = RING_GET_REQUEST(&bedata->ring, req_id);
313 	req->req_id = req_id;
314 	req->cmd = PVCALLS_SOCKET;
315 	req->u.socket.id = (uintptr_t) map;
316 	req->u.socket.domain = AF_INET;
317 	req->u.socket.type = SOCK_STREAM;
318 	req->u.socket.protocol = IPPROTO_IP;
319 
320 	bedata->ring.req_prod_pvt++;
321 	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
322 	spin_unlock(&bedata->socket_lock);
323 	if (notify)
324 		notify_remote_via_irq(bedata->irq);
325 
326 	wait_event(bedata->inflight_req,
327 		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
328 
329 	/* read req_id, then the content */
330 	smp_rmb();
331 	ret = bedata->rsp[req_id].ret;
332 	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
333 
334 	pvcalls_exit();
335 	return ret;
336 }
337 
338 static int create_active(struct sock_mapping *map, int *evtchn)
339 {
340 	void *bytes;
341 	int ret = -ENOMEM, irq = -1, i;
342 
343 	*evtchn = -1;
344 	init_waitqueue_head(&map->active.inflight_conn_req);
345 
346 	map->active.ring = (struct pvcalls_data_intf *)
347 		__get_free_page(GFP_KERNEL | __GFP_ZERO);
348 	if (map->active.ring == NULL)
349 		goto out_error;
350 	map->active.ring->ring_order = PVCALLS_RING_ORDER;
351 	bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
352 					PVCALLS_RING_ORDER);
353 	if (bytes == NULL)
354 		goto out_error;
355 	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
356 		map->active.ring->ref[i] = gnttab_grant_foreign_access(
357 			pvcalls_front_dev->otherend_id,
358 			pfn_to_gfn(virt_to_pfn(bytes) + i), 0);
359 
360 	map->active.ref = gnttab_grant_foreign_access(
361 		pvcalls_front_dev->otherend_id,
362 		pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);
363 
364 	map->active.data.in = bytes;
365 	map->active.data.out = bytes +
366 		XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
367 
368 	ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
369 	if (ret)
370 		goto out_error;
371 	irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
372 					0, "pvcalls-frontend", map);
373 	if (irq < 0) {
374 		ret = irq;
375 		goto out_error;
376 	}
377 
378 	map->active.irq = irq;
379 	map->active_socket = true;
380 	mutex_init(&map->active.in_mutex);
381 	mutex_init(&map->active.out_mutex);
382 
383 	return 0;
384 
385 out_error:
386 	if (*evtchn >= 0)
387 		xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
388 	kfree(map->active.data.in);
389 	kfree(map->active.ring);
390 	return ret;
391 }
392 
393 int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
394 				int addr_len, int flags)
395 {
396 	struct pvcalls_bedata *bedata;
397 	struct sock_mapping *map = NULL;
398 	struct xen_pvcalls_request *req;
399 	int notify, req_id, ret, evtchn;
400 
401 	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
402 		return -EOPNOTSUPP;
403 
404 	map = pvcalls_enter_sock(sock);
405 	if (IS_ERR(map))
406 		return PTR_ERR(map);
407 
408 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
409 
410 	spin_lock(&bedata->socket_lock);
411 	ret = get_request(bedata, &req_id);
412 	if (ret < 0) {
413 		spin_unlock(&bedata->socket_lock);
414 		pvcalls_exit_sock(sock);
415 		return ret;
416 	}
417 	ret = create_active(map, &evtchn);
418 	if (ret < 0) {
419 		spin_unlock(&bedata->socket_lock);
420 		pvcalls_exit_sock(sock);
421 		return ret;
422 	}
423 
424 	req = RING_GET_REQUEST(&bedata->ring, req_id);
425 	req->req_id = req_id;
426 	req->cmd = PVCALLS_CONNECT;
427 	req->u.connect.id = (uintptr_t)map;
428 	req->u.connect.len = addr_len;
429 	req->u.connect.flags = flags;
430 	req->u.connect.ref = map->active.ref;
431 	req->u.connect.evtchn = evtchn;
432 	memcpy(req->u.connect.addr, addr, sizeof(*addr));
433 
434 	map->sock = sock;
435 
436 	bedata->ring.req_prod_pvt++;
437 	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
438 	spin_unlock(&bedata->socket_lock);
439 
440 	if (notify)
441 		notify_remote_via_irq(bedata->irq);
442 
443 	wait_event(bedata->inflight_req,
444 		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
445 
446 	/* read req_id, then the content */
447 	smp_rmb();
448 	ret = bedata->rsp[req_id].ret;
449 	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
450 	pvcalls_exit_sock(sock);
451 	return ret;
452 }
453 
454 static int __write_ring(struct pvcalls_data_intf *intf,
455 			struct pvcalls_data *data,
456 			struct iov_iter *msg_iter,
457 			int len)
458 {
459 	RING_IDX cons, prod, size, masked_prod, masked_cons;
460 	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
461 	int32_t error;
462 
463 	error = intf->out_error;
464 	if (error < 0)
465 		return error;
466 	cons = intf->out_cons;
467 	prod = intf->out_prod;
468 	/* read indexes before continuing */
469 	virt_mb();
470 
471 	size = pvcalls_queued(prod, cons, array_size);
472 	if (size >= array_size)
473 		return -EINVAL;
474 	if (len > array_size - size)
475 		len = array_size - size;
476 
477 	masked_prod = pvcalls_mask(prod, array_size);
478 	masked_cons = pvcalls_mask(cons, array_size);
479 
480 	if (masked_prod < masked_cons) {
481 		len = copy_from_iter(data->out + masked_prod, len, msg_iter);
482 	} else {
483 		if (len > array_size - masked_prod) {
484 			int ret = copy_from_iter(data->out + masked_prod,
485 				       array_size - masked_prod, msg_iter);
486 			if (ret != array_size - masked_prod) {
487 				len = ret;
488 				goto out;
489 			}
490 			len = ret + copy_from_iter(data->out, len - ret, msg_iter);
491 		} else {
492 			len = copy_from_iter(data->out + masked_prod, len, msg_iter);
493 		}
494 	}
495 out:
496 	/* write to ring before updating pointer */
497 	virt_wmb();
498 	intf->out_prod += len;
499 
500 	return len;
501 }
502 
503 int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
504 			  size_t len)
505 {
506 	struct pvcalls_bedata *bedata;
507 	struct sock_mapping *map;
508 	int sent, tot_sent = 0;
509 	int count = 0, flags;
510 
511 	flags = msg->msg_flags;
512 	if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
513 		return -EOPNOTSUPP;
514 
515 	map = pvcalls_enter_sock(sock);
516 	if (IS_ERR(map))
517 		return PTR_ERR(map);
518 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
519 
520 	mutex_lock(&map->active.out_mutex);
521 	if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
522 		mutex_unlock(&map->active.out_mutex);
523 		pvcalls_exit_sock(sock);
524 		return -EAGAIN;
525 	}
526 	if (len > INT_MAX)
527 		len = INT_MAX;
528 
529 again:
530 	count++;
531 	sent = __write_ring(map->active.ring,
532 			    &map->active.data, &msg->msg_iter,
533 			    len);
534 	if (sent > 0) {
535 		len -= sent;
536 		tot_sent += sent;
537 		notify_remote_via_irq(map->active.irq);
538 	}
539 	if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
540 		goto again;
541 	if (sent < 0)
542 		tot_sent = sent;
543 
544 	mutex_unlock(&map->active.out_mutex);
545 	pvcalls_exit_sock(sock);
546 	return tot_sent;
547 }
548 
549 static int __read_ring(struct pvcalls_data_intf *intf,
550 		       struct pvcalls_data *data,
551 		       struct iov_iter *msg_iter,
552 		       size_t len, int flags)
553 {
554 	RING_IDX cons, prod, size, masked_prod, masked_cons;
555 	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
556 	int32_t error;
557 
558 	cons = intf->in_cons;
559 	prod = intf->in_prod;
560 	error = intf->in_error;
561 	/* get pointers before reading from the ring */
562 	virt_rmb();
563 	if (error < 0)
564 		return error;
565 
566 	size = pvcalls_queued(prod, cons, array_size);
567 	masked_prod = pvcalls_mask(prod, array_size);
568 	masked_cons = pvcalls_mask(cons, array_size);
569 
570 	if (size == 0)
571 		return 0;
572 
573 	if (len > size)
574 		len = size;
575 
576 	if (masked_prod > masked_cons) {
577 		len = copy_to_iter(data->in + masked_cons, len, msg_iter);
578 	} else {
579 		if (len > (array_size - masked_cons)) {
580 			int ret = copy_to_iter(data->in + masked_cons,
581 				     array_size - masked_cons, msg_iter);
582 			if (ret != array_size - masked_cons) {
583 				len = ret;
584 				goto out;
585 			}
586 			len = ret + copy_to_iter(data->in, len - ret, msg_iter);
587 		} else {
588 			len = copy_to_iter(data->in + masked_cons, len, msg_iter);
589 		}
590 	}
591 out:
592 	/* read data from the ring before increasing the index */
593 	virt_mb();
594 	if (!(flags & MSG_PEEK))
595 		intf->in_cons += len;
596 
597 	return len;
598 }
599 
600 int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
601 		     int flags)
602 {
603 	struct pvcalls_bedata *bedata;
604 	int ret;
605 	struct sock_mapping *map;
606 
607 	if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
608 		return -EOPNOTSUPP;
609 
610 	map = pvcalls_enter_sock(sock);
611 	if (IS_ERR(map))
612 		return PTR_ERR(map);
613 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
614 
615 	mutex_lock(&map->active.in_mutex);
616 	if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
617 		len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
618 
619 	while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
620 		wait_event_interruptible(map->active.inflight_conn_req,
621 					 pvcalls_front_read_todo(map));
622 	}
623 	ret = __read_ring(map->active.ring, &map->active.data,
624 			  &msg->msg_iter, len, flags);
625 
626 	if (ret > 0)
627 		notify_remote_via_irq(map->active.irq);
628 	if (ret == 0)
629 		ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
630 	if (ret == -ENOTCONN)
631 		ret = 0;
632 
633 	mutex_unlock(&map->active.in_mutex);
634 	pvcalls_exit_sock(sock);
635 	return ret;
636 }
637 
638 int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
639 {
640 	struct pvcalls_bedata *bedata;
641 	struct sock_mapping *map = NULL;
642 	struct xen_pvcalls_request *req;
643 	int notify, req_id, ret;
644 
645 	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
646 		return -EOPNOTSUPP;
647 
648 	map = pvcalls_enter_sock(sock);
649 	if (IS_ERR(map))
650 		return PTR_ERR(map);
651 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
652 
653 	spin_lock(&bedata->socket_lock);
654 	ret = get_request(bedata, &req_id);
655 	if (ret < 0) {
656 		spin_unlock(&bedata->socket_lock);
657 		pvcalls_exit_sock(sock);
658 		return ret;
659 	}
660 	req = RING_GET_REQUEST(&bedata->ring, req_id);
661 	req->req_id = req_id;
662 	map->sock = sock;
663 	req->cmd = PVCALLS_BIND;
664 	req->u.bind.id = (uintptr_t)map;
665 	memcpy(req->u.bind.addr, addr, sizeof(*addr));
666 	req->u.bind.len = addr_len;
667 
668 	init_waitqueue_head(&map->passive.inflight_accept_req);
669 
670 	map->active_socket = false;
671 
672 	bedata->ring.req_prod_pvt++;
673 	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
674 	spin_unlock(&bedata->socket_lock);
675 	if (notify)
676 		notify_remote_via_irq(bedata->irq);
677 
678 	wait_event(bedata->inflight_req,
679 		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
680 
681 	/* read req_id, then the content */
682 	smp_rmb();
683 	ret = bedata->rsp[req_id].ret;
684 	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
685 
686 	map->passive.status = PVCALLS_STATUS_BIND;
687 	pvcalls_exit_sock(sock);
688 	return 0;
689 }
690 
691 int pvcalls_front_listen(struct socket *sock, int backlog)
692 {
693 	struct pvcalls_bedata *bedata;
694 	struct sock_mapping *map;
695 	struct xen_pvcalls_request *req;
696 	int notify, req_id, ret;
697 
698 	map = pvcalls_enter_sock(sock);
699 	if (IS_ERR(map))
700 		return PTR_ERR(map);
701 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
702 
703 	if (map->passive.status != PVCALLS_STATUS_BIND) {
704 		pvcalls_exit_sock(sock);
705 		return -EOPNOTSUPP;
706 	}
707 
708 	spin_lock(&bedata->socket_lock);
709 	ret = get_request(bedata, &req_id);
710 	if (ret < 0) {
711 		spin_unlock(&bedata->socket_lock);
712 		pvcalls_exit_sock(sock);
713 		return ret;
714 	}
715 	req = RING_GET_REQUEST(&bedata->ring, req_id);
716 	req->req_id = req_id;
717 	req->cmd = PVCALLS_LISTEN;
718 	req->u.listen.id = (uintptr_t) map;
719 	req->u.listen.backlog = backlog;
720 
721 	bedata->ring.req_prod_pvt++;
722 	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
723 	spin_unlock(&bedata->socket_lock);
724 	if (notify)
725 		notify_remote_via_irq(bedata->irq);
726 
727 	wait_event(bedata->inflight_req,
728 		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
729 
730 	/* read req_id, then the content */
731 	smp_rmb();
732 	ret = bedata->rsp[req_id].ret;
733 	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
734 
735 	map->passive.status = PVCALLS_STATUS_LISTEN;
736 	pvcalls_exit_sock(sock);
737 	return ret;
738 }
739 
740 int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
741 {
742 	struct pvcalls_bedata *bedata;
743 	struct sock_mapping *map;
744 	struct sock_mapping *map2 = NULL;
745 	struct xen_pvcalls_request *req;
746 	int notify, req_id, ret, evtchn, nonblock;
747 
748 	map = pvcalls_enter_sock(sock);
749 	if (IS_ERR(map))
750 		return PTR_ERR(map);
751 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
752 
753 	if (map->passive.status != PVCALLS_STATUS_LISTEN) {
754 		pvcalls_exit_sock(sock);
755 		return -EINVAL;
756 	}
757 
758 	nonblock = flags & SOCK_NONBLOCK;
759 	/*
760 	 * Backend only supports 1 inflight accept request, will return
761 	 * errors for the others
762 	 */
763 	if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
764 			     (void *)&map->passive.flags)) {
765 		req_id = READ_ONCE(map->passive.inflight_req_id);
766 		if (req_id != PVCALLS_INVALID_ID &&
767 		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
768 			map2 = map->passive.accept_map;
769 			goto received;
770 		}
771 		if (nonblock) {
772 			pvcalls_exit_sock(sock);
773 			return -EAGAIN;
774 		}
775 		if (wait_event_interruptible(map->passive.inflight_accept_req,
776 			!test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
777 					  (void *)&map->passive.flags))) {
778 			pvcalls_exit_sock(sock);
779 			return -EINTR;
780 		}
781 	}
782 
783 	spin_lock(&bedata->socket_lock);
784 	ret = get_request(bedata, &req_id);
785 	if (ret < 0) {
786 		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
787 			  (void *)&map->passive.flags);
788 		spin_unlock(&bedata->socket_lock);
789 		pvcalls_exit_sock(sock);
790 		return ret;
791 	}
792 	map2 = kzalloc(sizeof(*map2), GFP_ATOMIC);
793 	if (map2 == NULL) {
794 		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
795 			  (void *)&map->passive.flags);
796 		spin_unlock(&bedata->socket_lock);
797 		pvcalls_exit_sock(sock);
798 		return -ENOMEM;
799 	}
800 	ret = create_active(map2, &evtchn);
801 	if (ret < 0) {
802 		kfree(map2);
803 		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
804 			  (void *)&map->passive.flags);
805 		spin_unlock(&bedata->socket_lock);
806 		pvcalls_exit_sock(sock);
807 		return ret;
808 	}
809 	list_add_tail(&map2->list, &bedata->socket_mappings);
810 
811 	req = RING_GET_REQUEST(&bedata->ring, req_id);
812 	req->req_id = req_id;
813 	req->cmd = PVCALLS_ACCEPT;
814 	req->u.accept.id = (uintptr_t) map;
815 	req->u.accept.ref = map2->active.ref;
816 	req->u.accept.id_new = (uintptr_t) map2;
817 	req->u.accept.evtchn = evtchn;
818 	map->passive.accept_map = map2;
819 
820 	bedata->ring.req_prod_pvt++;
821 	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
822 	spin_unlock(&bedata->socket_lock);
823 	if (notify)
824 		notify_remote_via_irq(bedata->irq);
825 	/* We could check if we have received a response before returning. */
826 	if (nonblock) {
827 		WRITE_ONCE(map->passive.inflight_req_id, req_id);
828 		pvcalls_exit_sock(sock);
829 		return -EAGAIN;
830 	}
831 
832 	if (wait_event_interruptible(bedata->inflight_req,
833 		READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
834 		pvcalls_exit_sock(sock);
835 		return -EINTR;
836 	}
837 	/* read req_id, then the content */
838 	smp_rmb();
839 
840 received:
841 	map2->sock = newsock;
842 	newsock->sk = kzalloc(sizeof(*newsock->sk), GFP_KERNEL);
843 	if (!newsock->sk) {
844 		bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
845 		map->passive.inflight_req_id = PVCALLS_INVALID_ID;
846 		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
847 			  (void *)&map->passive.flags);
848 		pvcalls_front_free_map(bedata, map2);
849 		pvcalls_exit_sock(sock);
850 		return -ENOMEM;
851 	}
852 	newsock->sk->sk_send_head = (void *)map2;
853 
854 	ret = bedata->rsp[req_id].ret;
855 	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
856 	map->passive.inflight_req_id = PVCALLS_INVALID_ID;
857 
858 	clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
859 	wake_up(&map->passive.inflight_accept_req);
860 
861 	pvcalls_exit_sock(sock);
862 	return ret;
863 }
864 
865 static __poll_t pvcalls_front_poll_passive(struct file *file,
866 					       struct pvcalls_bedata *bedata,
867 					       struct sock_mapping *map,
868 					       poll_table *wait)
869 {
870 	int notify, req_id, ret;
871 	struct xen_pvcalls_request *req;
872 
873 	if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
874 		     (void *)&map->passive.flags)) {
875 		uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);
876 
877 		if (req_id != PVCALLS_INVALID_ID &&
878 		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
879 			return EPOLLIN | EPOLLRDNORM;
880 
881 		poll_wait(file, &map->passive.inflight_accept_req, wait);
882 		return 0;
883 	}
884 
885 	if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
886 			       (void *)&map->passive.flags))
887 		return EPOLLIN | EPOLLRDNORM;
888 
889 	/*
890 	 * First check RET, then INFLIGHT. No barriers necessary to
891 	 * ensure execution ordering because of the conditional
892 	 * instructions creating control dependencies.
893 	 */
894 
895 	if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
896 			     (void *)&map->passive.flags)) {
897 		poll_wait(file, &bedata->inflight_req, wait);
898 		return 0;
899 	}
900 
901 	spin_lock(&bedata->socket_lock);
902 	ret = get_request(bedata, &req_id);
903 	if (ret < 0) {
904 		spin_unlock(&bedata->socket_lock);
905 		return ret;
906 	}
907 	req = RING_GET_REQUEST(&bedata->ring, req_id);
908 	req->req_id = req_id;
909 	req->cmd = PVCALLS_POLL;
910 	req->u.poll.id = (uintptr_t) map;
911 
912 	bedata->ring.req_prod_pvt++;
913 	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
914 	spin_unlock(&bedata->socket_lock);
915 	if (notify)
916 		notify_remote_via_irq(bedata->irq);
917 
918 	poll_wait(file, &bedata->inflight_req, wait);
919 	return 0;
920 }
921 
922 static __poll_t pvcalls_front_poll_active(struct file *file,
923 					      struct pvcalls_bedata *bedata,
924 					      struct sock_mapping *map,
925 					      poll_table *wait)
926 {
927 	__poll_t mask = 0;
928 	int32_t in_error, out_error;
929 	struct pvcalls_data_intf *intf = map->active.ring;
930 
931 	out_error = intf->out_error;
932 	in_error = intf->in_error;
933 
934 	poll_wait(file, &map->active.inflight_conn_req, wait);
935 	if (pvcalls_front_write_todo(map))
936 		mask |= EPOLLOUT | EPOLLWRNORM;
937 	if (pvcalls_front_read_todo(map))
938 		mask |= EPOLLIN | EPOLLRDNORM;
939 	if (in_error != 0 || out_error != 0)
940 		mask |= EPOLLERR;
941 
942 	return mask;
943 }
944 
945 __poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
946 			       poll_table *wait)
947 {
948 	struct pvcalls_bedata *bedata;
949 	struct sock_mapping *map;
950 	__poll_t ret;
951 
952 	map = pvcalls_enter_sock(sock);
953 	if (IS_ERR(map))
954 		return EPOLLNVAL;
955 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
956 
957 	if (map->active_socket)
958 		ret = pvcalls_front_poll_active(file, bedata, map, wait);
959 	else
960 		ret = pvcalls_front_poll_passive(file, bedata, map, wait);
961 	pvcalls_exit_sock(sock);
962 	return ret;
963 }
964 
965 int pvcalls_front_release(struct socket *sock)
966 {
967 	struct pvcalls_bedata *bedata;
968 	struct sock_mapping *map;
969 	int req_id, notify, ret;
970 	struct xen_pvcalls_request *req;
971 
972 	if (sock->sk == NULL)
973 		return 0;
974 
975 	map = pvcalls_enter_sock(sock);
976 	if (IS_ERR(map)) {
977 		if (PTR_ERR(map) == -ENOTCONN)
978 			return -EIO;
979 		else
980 			return 0;
981 	}
982 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
983 
984 	spin_lock(&bedata->socket_lock);
985 	ret = get_request(bedata, &req_id);
986 	if (ret < 0) {
987 		spin_unlock(&bedata->socket_lock);
988 		pvcalls_exit_sock(sock);
989 		return ret;
990 	}
991 	sock->sk->sk_send_head = NULL;
992 
993 	req = RING_GET_REQUEST(&bedata->ring, req_id);
994 	req->req_id = req_id;
995 	req->cmd = PVCALLS_RELEASE;
996 	req->u.release.id = (uintptr_t)map;
997 
998 	bedata->ring.req_prod_pvt++;
999 	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
1000 	spin_unlock(&bedata->socket_lock);
1001 	if (notify)
1002 		notify_remote_via_irq(bedata->irq);
1003 
1004 	wait_event(bedata->inflight_req,
1005 		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
1006 
1007 	if (map->active_socket) {
1008 		/*
1009 		 * Set in_error and wake up inflight_conn_req to force
1010 		 * recvmsg waiters to exit.
1011 		 */
1012 		map->active.ring->in_error = -EBADF;
1013 		wake_up_interruptible(&map->active.inflight_conn_req);
1014 
1015 		/*
1016 		 * We need to make sure that sendmsg/recvmsg on this socket have
1017 		 * not started before we've cleared sk_send_head here. The
1018 		 * easiest way to guarantee this is to see that no pvcalls
1019 		 * (other than us) is in progress on this socket.
1020 		 */
1021 		while (atomic_read(&map->refcount) > 1)
1022 			cpu_relax();
1023 
1024 		pvcalls_front_free_map(bedata, map);
1025 	} else {
1026 		wake_up(&bedata->inflight_req);
1027 		wake_up(&map->passive.inflight_accept_req);
1028 
1029 		while (atomic_read(&map->refcount) > 1)
1030 			cpu_relax();
1031 
1032 		spin_lock(&bedata->socket_lock);
1033 		list_del(&map->list);
1034 		spin_unlock(&bedata->socket_lock);
1035 		if (READ_ONCE(map->passive.inflight_req_id) !=
1036 		    PVCALLS_INVALID_ID) {
1037 			pvcalls_front_free_map(bedata,
1038 					       map->passive.accept_map);
1039 		}
1040 		kfree(map);
1041 	}
1042 	WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);
1043 
1044 	pvcalls_exit();
1045 	return 0;
1046 }
1047 
/* xenbus device types handled by this frontend; the empty entry ends it. */
static const struct xenbus_device_id pvcalls_front_ids[] = {
	{ "pvcalls" },
	{ "" }
};
1052 
1053 static int pvcalls_front_remove(struct xenbus_device *dev)
1054 {
1055 	struct pvcalls_bedata *bedata;
1056 	struct sock_mapping *map = NULL, *n;
1057 
1058 	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
1059 	dev_set_drvdata(&dev->dev, NULL);
1060 	pvcalls_front_dev = NULL;
1061 	if (bedata->irq >= 0)
1062 		unbind_from_irqhandler(bedata->irq, dev);
1063 
1064 	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
1065 		map->sock->sk->sk_send_head = NULL;
1066 		if (map->active_socket) {
1067 			map->active.ring->in_error = -EBADF;
1068 			wake_up_interruptible(&map->active.inflight_conn_req);
1069 		}
1070 	}
1071 
1072 	smp_mb();
1073 	while (atomic_read(&pvcalls_refcount) > 0)
1074 		cpu_relax();
1075 	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
1076 		if (map->active_socket) {
1077 			/* No need to lock, refcount is 0 */
1078 			pvcalls_front_free_map(bedata, map);
1079 		} else {
1080 			list_del(&map->list);
1081 			kfree(map);
1082 		}
1083 	}
1084 	if (bedata->ref != -1)
1085 		gnttab_end_foreign_access(bedata->ref, 0, 0);
1086 	kfree(bedata->ring.sring);
1087 	kfree(bedata);
1088 	xenbus_switch_state(dev, XenbusStateClosed);
1089 	return 0;
1090 }
1091 
1092 static int pvcalls_front_probe(struct xenbus_device *dev,
1093 			  const struct xenbus_device_id *id)
1094 {
1095 	int ret = -ENOMEM, evtchn, i;
1096 	unsigned int max_page_order, function_calls, len;
1097 	char *versions;
1098 	grant_ref_t gref_head = 0;
1099 	struct xenbus_transaction xbt;
1100 	struct pvcalls_bedata *bedata = NULL;
1101 	struct xen_pvcalls_sring *sring;
1102 
1103 	if (pvcalls_front_dev != NULL) {
1104 		dev_err(&dev->dev, "only one PV Calls connection supported\n");
1105 		return -EINVAL;
1106 	}
1107 
1108 	versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
1109 	if (IS_ERR(versions))
1110 		return PTR_ERR(versions);
1111 	if (!len)
1112 		return -EINVAL;
1113 	if (strcmp(versions, "1")) {
1114 		kfree(versions);
1115 		return -EINVAL;
1116 	}
1117 	kfree(versions);
1118 	max_page_order = xenbus_read_unsigned(dev->otherend,
1119 					      "max-page-order", 0);
1120 	if (max_page_order < PVCALLS_RING_ORDER)
1121 		return -ENODEV;
1122 	function_calls = xenbus_read_unsigned(dev->otherend,
1123 					      "function-calls", 0);
1124 	/* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
1125 	if (function_calls != 1)
1126 		return -ENODEV;
1127 	pr_info("%s max-page-order is %u\n", __func__, max_page_order);
1128 
1129 	bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
1130 	if (!bedata)
1131 		return -ENOMEM;
1132 
1133 	dev_set_drvdata(&dev->dev, bedata);
1134 	pvcalls_front_dev = dev;
1135 	init_waitqueue_head(&bedata->inflight_req);
1136 	INIT_LIST_HEAD(&bedata->socket_mappings);
1137 	spin_lock_init(&bedata->socket_lock);
1138 	bedata->irq = -1;
1139 	bedata->ref = -1;
1140 
1141 	for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
1142 		bedata->rsp[i].req_id = PVCALLS_INVALID_ID;
1143 
1144 	sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
1145 							     __GFP_ZERO);
1146 	if (!sring)
1147 		goto error;
1148 	SHARED_RING_INIT(sring);
1149 	FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);
1150 
1151 	ret = xenbus_alloc_evtchn(dev, &evtchn);
1152 	if (ret)
1153 		goto error;
1154 
1155 	bedata->irq = bind_evtchn_to_irqhandler(evtchn,
1156 						pvcalls_front_event_handler,
1157 						0, "pvcalls-frontend", dev);
1158 	if (bedata->irq < 0) {
1159 		ret = bedata->irq;
1160 		goto error;
1161 	}
1162 
1163 	ret = gnttab_alloc_grant_references(1, &gref_head);
1164 	if (ret < 0)
1165 		goto error;
1166 	ret = gnttab_claim_grant_reference(&gref_head);
1167 	if (ret < 0)
1168 		goto error;
1169 	bedata->ref = ret;
1170 	gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
1171 					virt_to_gfn((void *)sring), 0);
1172 
1173  again:
1174 	ret = xenbus_transaction_start(&xbt);
1175 	if (ret) {
1176 		xenbus_dev_fatal(dev, ret, "starting transaction");
1177 		goto error;
1178 	}
1179 	ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
1180 	if (ret)
1181 		goto error_xenbus;
1182 	ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
1183 	if (ret)
1184 		goto error_xenbus;
1185 	ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
1186 			    evtchn);
1187 	if (ret)
1188 		goto error_xenbus;
1189 	ret = xenbus_transaction_end(xbt, 0);
1190 	if (ret) {
1191 		if (ret == -EAGAIN)
1192 			goto again;
1193 		xenbus_dev_fatal(dev, ret, "completing transaction");
1194 		goto error;
1195 	}
1196 	xenbus_switch_state(dev, XenbusStateInitialised);
1197 
1198 	return 0;
1199 
1200  error_xenbus:
1201 	xenbus_transaction_end(xbt, 1);
1202 	xenbus_dev_fatal(dev, ret, "writing xenstore");
1203  error:
1204 	pvcalls_front_remove(dev);
1205 	return ret;
1206 }
1207 
1208 static void pvcalls_front_changed(struct xenbus_device *dev,
1209 			    enum xenbus_state backend_state)
1210 {
1211 	switch (backend_state) {
1212 	case XenbusStateReconfiguring:
1213 	case XenbusStateReconfigured:
1214 	case XenbusStateInitialising:
1215 	case XenbusStateInitialised:
1216 	case XenbusStateUnknown:
1217 		break;
1218 
1219 	case XenbusStateInitWait:
1220 		break;
1221 
1222 	case XenbusStateConnected:
1223 		xenbus_switch_state(dev, XenbusStateConnected);
1224 		break;
1225 
1226 	case XenbusStateClosed:
1227 		if (dev->state == XenbusStateClosed)
1228 			break;
1229 		/* Missed the backend's CLOSING state */
1230 		/* fall through */
1231 	case XenbusStateClosing:
1232 		xenbus_frontend_closed(dev);
1233 		break;
1234 	}
1235 }
1236 
1237 static struct xenbus_driver pvcalls_front_driver = {
1238 	.ids = pvcalls_front_ids,
1239 	.probe = pvcalls_front_probe,
1240 	.remove = pvcalls_front_remove,
1241 	.otherend_changed = pvcalls_front_changed,
1242 };
1243 
1244 static int __init pvcalls_frontend_init(void)
1245 {
1246 	if (!xen_domain())
1247 		return -ENODEV;
1248 
1249 	pr_info("Initialising Xen pvcalls frontend driver\n");
1250 
1251 	return xenbus_register_frontend(&pvcalls_front_driver);
1252 }
1253 
1254 module_init(pvcalls_frontend_init);
1255 
1256 MODULE_DESCRIPTION("Xen PV Calls frontend driver");
1257 MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
1258 MODULE_LICENSE("GPL");
1259