xref: /openbmc/linux/net/kcm/kcmsock.c (revision 56a0eccd)
1 #include <linux/bpf.h>
2 #include <linux/errno.h>
3 #include <linux/errqueue.h>
4 #include <linux/file.h>
5 #include <linux/in.h>
6 #include <linux/kernel.h>
7 #include <linux/module.h>
8 #include <linux/net.h>
9 #include <linux/netdevice.h>
10 #include <linux/poll.h>
11 #include <linux/rculist.h>
12 #include <linux/skbuff.h>
13 #include <linux/socket.h>
14 #include <linux/uaccess.h>
15 #include <linux/workqueue.h>
16 #include <net/kcm.h>
17 #include <net/netns/generic.h>
18 #include <net/sock.h>
19 #include <net/tcp.h>
20 #include <uapi/linux/kcm.h>
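
/* Illustrative userspace usage sketch: create a KCM socket, attach a
 * connected TCP socket together with a BPF framing program, and clone a
 * second KCM socket onto the same mux. tcp_fd and bpf_prog_fd are
 * placeholder descriptors assumed to already exist.
 *
 *	int kcm_fd = socket(AF_KCM, SOCK_DGRAM, KCMPROTO_CONNECTED);
 *
 *	struct kcm_attach attach = {
 *		.fd = tcp_fd,
 *		.bpf_fd = bpf_prog_fd,
 *	};
 *	ioctl(kcm_fd, SIOCKCMATTACH, &attach);
 *
 *	struct kcm_clone clone_info;
 *	ioctl(kcm_fd, SIOCKCMCLONE, &clone_info);
 *
 * clone_info.fd is then a second KCM socket sharing the mux; see kcm_attach(),
 * kcm_clone() and kcm_ioctl() below.
 */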
21 
22 unsigned int kcm_net_id;
23 
24 static struct kmem_cache *kcm_psockp __read_mostly;
25 static struct kmem_cache *kcm_muxp __read_mostly;
26 static struct workqueue_struct *kcm_wq;
27 
28 static inline struct kcm_sock *kcm_sk(const struct sock *sk)
29 {
30 	return (struct kcm_sock *)sk;
31 }
32 
33 static inline struct kcm_tx_msg *kcm_tx_msg(struct sk_buff *skb)
34 {
35 	return (struct kcm_tx_msg *)skb->cb;
36 }
37 
38 static inline struct kcm_rx_msg *kcm_rx_msg(struct sk_buff *skb)
39 {
40 	return (struct kcm_rx_msg *)((void *)skb->cb +
41 				     offsetof(struct qdisc_skb_cb, data));
42 }
43 
44 static void report_csk_error(struct sock *csk, int err)
45 {
46 	csk->sk_err = EPIPE;
47 	csk->sk_error_report(csk);
48 }
49 
50 /* Callback lock held */
51 static void kcm_abort_rx_psock(struct kcm_psock *psock, int err,
52 			       struct sk_buff *skb)
53 {
54 	struct sock *csk = psock->sk;
55 
56 	/* Unrecoverable error in receive */
57 
58 	del_timer(&psock->rx_msg_timer);
59 
60 	if (psock->rx_stopped)
61 		return;
62 
63 	psock->rx_stopped = 1;
64 	KCM_STATS_INCR(psock->stats.rx_aborts);
65 
66 	/* Report an error on the lower socket */
67 	report_csk_error(csk, err);
68 }
69 
70 static void kcm_abort_tx_psock(struct kcm_psock *psock, int err,
71 			       bool wakeup_kcm)
72 {
73 	struct sock *csk = psock->sk;
74 	struct kcm_mux *mux = psock->mux;
75 
76 	/* Unrecoverable error in transmit */
77 
78 	spin_lock_bh(&mux->lock);
79 
80 	if (psock->tx_stopped) {
81 		spin_unlock_bh(&mux->lock);
82 		return;
83 	}
84 
85 	psock->tx_stopped = 1;
86 	KCM_STATS_INCR(psock->stats.tx_aborts);
87 
88 	if (!psock->tx_kcm) {
89 		/* Take off psocks_avail list */
90 		list_del(&psock->psock_avail_list);
91 	} else if (wakeup_kcm) {
92 		/* In this case psock is being aborted while outside of
93 		 * write_msgs and psock is reserved. Schedule tx_work
94 		 * to handle the failure there. Need to commit tx_stopped
95 		 * before queuing work.
96 		 */
97 		smp_mb();
98 
99 		queue_work(kcm_wq, &psock->tx_kcm->tx_work);
100 	}
101 
102 	spin_unlock_bh(&mux->lock);
103 
104 	/* Report error on lower socket */
105 	report_csk_error(csk, err);
106 }
107 
108 /* RX mux lock held. */
109 static void kcm_update_rx_mux_stats(struct kcm_mux *mux,
110 				    struct kcm_psock *psock)
111 {
112 	KCM_STATS_ADD(mux->stats.rx_bytes,
113 		      psock->stats.rx_bytes - psock->saved_rx_bytes);
114 	mux->stats.rx_msgs +=
115 		psock->stats.rx_msgs - psock->saved_rx_msgs;
116 	psock->saved_rx_msgs = psock->stats.rx_msgs;
117 	psock->saved_rx_bytes = psock->stats.rx_bytes;
118 }
119 
120 static void kcm_update_tx_mux_stats(struct kcm_mux *mux,
121 				    struct kcm_psock *psock)
122 {
123 	KCM_STATS_ADD(mux->stats.tx_bytes,
124 		      psock->stats.tx_bytes - psock->saved_tx_bytes);
125 	mux->stats.tx_msgs +=
126 		psock->stats.tx_msgs - psock->saved_tx_msgs;
127 	psock->saved_tx_msgs = psock->stats.tx_msgs;
128 	psock->saved_tx_bytes = psock->stats.tx_bytes;
129 }
130 
131 static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
132 
133 /* KCM is ready to receive messages on its queue -- either the KCM is new or
134  * has become unblocked after being blocked on a full socket buffer. Queue any
135  * pending ready messages on a psock. RX mux lock held.
136  */
137 static void kcm_rcv_ready(struct kcm_sock *kcm)
138 {
139 	struct kcm_mux *mux = kcm->mux;
140 	struct kcm_psock *psock;
141 	struct sk_buff *skb;
142 
143 	if (unlikely(kcm->rx_wait || kcm->rx_psock || kcm->rx_disabled))
144 		return;
145 
146 	while (unlikely((skb = __skb_dequeue(&mux->rx_hold_queue)))) {
147 		if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
148 			/* Assuming buffer limit has been reached */
149 			skb_queue_head(&mux->rx_hold_queue, skb);
150 			WARN_ON(!sk_rmem_alloc_get(&kcm->sk));
151 			return;
152 		}
153 	}
154 
155 	while (!list_empty(&mux->psocks_ready)) {
156 		psock = list_first_entry(&mux->psocks_ready, struct kcm_psock,
157 					 psock_ready_list);
158 
159 		if (kcm_queue_rcv_skb(&kcm->sk, psock->ready_rx_msg)) {
160 			/* Assuming buffer limit has been reached */
161 			WARN_ON(!sk_rmem_alloc_get(&kcm->sk));
162 			return;
163 		}
164 
165 		/* Consumed the ready message on the psock. Schedule rx_work to
166 		 * get more messages.
167 		 */
168 		list_del(&psock->psock_ready_list);
169 		psock->ready_rx_msg = NULL;
170 
171 		/* Commit clearing of ready_rx_msg for queuing work */
172 		smp_mb();
173 
174 		queue_work(kcm_wq, &psock->rx_work);
175 	}
176 
177 	/* Buffer limit is okay now, add to ready list */
178 	list_add_tail(&kcm->wait_rx_list,
179 		      &kcm->mux->kcm_rx_waiters);
180 	kcm->rx_wait = true;
181 }
182 
183 static void kcm_rfree(struct sk_buff *skb)
184 {
185 	struct sock *sk = skb->sk;
186 	struct kcm_sock *kcm = kcm_sk(sk);
187 	struct kcm_mux *mux = kcm->mux;
188 	unsigned int len = skb->truesize;
189 
190 	sk_mem_uncharge(sk, len);
191 	atomic_sub(len, &sk->sk_rmem_alloc);
192 
193 	/* For reading rx_wait and rx_psock without holding lock */
194 	smp_mb__after_atomic();
195 
196 	if (!kcm->rx_wait && !kcm->rx_psock &&
197 	    sk_rmem_alloc_get(sk) < sk->sk_rcvlowat) {
198 		spin_lock_bh(&mux->rx_lock);
199 		kcm_rcv_ready(kcm);
200 		spin_unlock_bh(&mux->rx_lock);
201 	}
202 }
203 
204 static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
205 {
206 	struct sk_buff_head *list = &sk->sk_receive_queue;
207 
208 	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf)
209 		return -ENOMEM;
210 
211 	if (!sk_rmem_schedule(sk, skb, skb->truesize))
212 		return -ENOBUFS;
213 
214 	skb->dev = NULL;
215 
216 	skb_orphan(skb);
217 	skb->sk = sk;
218 	skb->destructor = kcm_rfree;
219 	atomic_add(skb->truesize, &sk->sk_rmem_alloc);
220 	sk_mem_charge(sk, skb->truesize);
221 
222 	skb_queue_tail(list, skb);
223 
224 	if (!sock_flag(sk, SOCK_DEAD))
225 		sk->sk_data_ready(sk);
226 
227 	return 0;
228 }
229 
230 /* Requeue received messages for a kcm socket to other kcm sockets. This is
231  * called when receive is disabled on a kcm socket.
232  * RX mux lock held.
233  */
234 static void requeue_rx_msgs(struct kcm_mux *mux, struct sk_buff_head *head)
235 {
236 	struct sk_buff *skb;
237 	struct kcm_sock *kcm;
238 
239 	while ((skb = __skb_dequeue(head))) {
240 		/* Reset destructor to avoid calling kcm_rcv_ready */
241 		skb->destructor = sock_rfree;
242 		skb_orphan(skb);
243 try_again:
244 		if (list_empty(&mux->kcm_rx_waiters)) {
245 			skb_queue_tail(&mux->rx_hold_queue, skb);
246 			continue;
247 		}
248 
249 		kcm = list_first_entry(&mux->kcm_rx_waiters,
250 				       struct kcm_sock, wait_rx_list);
251 
252 		if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
253 			/* Should mean socket buffer full */
254 			list_del(&kcm->wait_rx_list);
255 			kcm->rx_wait = false;
256 
257 			/* Commit rx_wait to read in kcm_rfree */
258 			smp_wmb();
259 
260 			goto try_again;
261 		}
262 	}
263 }
264 
265 /* Lower sock lock held */
266 static struct kcm_sock *reserve_rx_kcm(struct kcm_psock *psock,
267 				       struct sk_buff *head)
268 {
269 	struct kcm_mux *mux = psock->mux;
270 	struct kcm_sock *kcm;
271 
272 	WARN_ON(psock->ready_rx_msg);
273 
274 	if (psock->rx_kcm)
275 		return psock->rx_kcm;
276 
277 	spin_lock_bh(&mux->rx_lock);
278 
279 	if (psock->rx_kcm) {
280 		spin_unlock_bh(&mux->rx_lock);
281 		return psock->rx_kcm;
282 	}
283 
284 	kcm_update_rx_mux_stats(mux, psock);
285 
286 	if (list_empty(&mux->kcm_rx_waiters)) {
287 		psock->ready_rx_msg = head;
288 		list_add_tail(&psock->psock_ready_list,
289 			      &mux->psocks_ready);
290 		spin_unlock_bh(&mux->rx_lock);
291 		return NULL;
292 	}
293 
294 	kcm = list_first_entry(&mux->kcm_rx_waiters,
295 			       struct kcm_sock, wait_rx_list);
296 	list_del(&kcm->wait_rx_list);
297 	kcm->rx_wait = false;
298 
299 	psock->rx_kcm = kcm;
300 	kcm->rx_psock = psock;
301 
302 	spin_unlock_bh(&mux->rx_lock);
303 
304 	return kcm;
305 }
306 
307 static void kcm_done(struct kcm_sock *kcm);
308 
309 static void kcm_done_work(struct work_struct *w)
310 {
311 	kcm_done(container_of(w, struct kcm_sock, done_work));
312 }
313 
314 /* Lower sock held */
315 static void unreserve_rx_kcm(struct kcm_psock *psock,
316 			     bool rcv_ready)
317 {
318 	struct kcm_sock *kcm = psock->rx_kcm;
319 	struct kcm_mux *mux = psock->mux;
320 
321 	if (!kcm)
322 		return;
323 
324 	spin_lock_bh(&mux->rx_lock);
325 
326 	psock->rx_kcm = NULL;
327 	kcm->rx_psock = NULL;
328 
329 	/* Commit kcm->rx_psock before sk_rmem_alloc_get to sync with
330 	 * kcm_rfree
331 	 */
332 	smp_mb();
333 
334 	if (unlikely(kcm->done)) {
335 		spin_unlock_bh(&mux->rx_lock);
336 
337 		/* Need to run kcm_done in a task since we need to acquire
338 		 * callback locks which may already be held here.
339 		 */
340 		INIT_WORK(&kcm->done_work, kcm_done_work);
341 		schedule_work(&kcm->done_work);
342 		return;
343 	}
344 
345 	if (unlikely(kcm->rx_disabled)) {
346 		requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue);
347 	} else if (rcv_ready || unlikely(!sk_rmem_alloc_get(&kcm->sk))) {
348 		/* Check for the degenerate race with rx_wait in which all
349 		 * data was dequeued (accounted for in kcm_rfree).
350 		 */
351 		kcm_rcv_ready(kcm);
352 	}
353 	spin_unlock_bh(&mux->rx_lock);
354 }
355 
356 static void kcm_start_rx_timer(struct kcm_psock *psock)
357 {
358 	if (psock->sk->sk_rcvtimeo)
359 		mod_timer(&psock->rx_msg_timer, psock->sk->sk_rcvtimeo);
360 }
361 
362 /* Macro to invoke filter function. */
363 #define KCM_RUN_FILTER(prog, ctx) \
364 	(*prog->bpf_func)(ctx, prog->insnsi)
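
/* Illustrative framing program (a sketch with an assumed wire format): for a
 * protocol whose messages begin with a 2-byte network-order length field that
 * counts only the payload, a BPF_PROG_TYPE_SOCKET_FILTER program could return
 * roughly
 *
 *	return load_half(skb, 0) + 2;
 *
 * i.e. header plus payload. kcm_tcp_recv() below records this value as
 * rxm->full_len; a return of 0 means more header bytes are needed.
 */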
365 
366 /* Lower socket lock held */
367 static int kcm_tcp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
368 			unsigned int orig_offset, size_t orig_len)
369 {
370 	struct kcm_psock *psock = (struct kcm_psock *)desc->arg.data;
371 	struct kcm_rx_msg *rxm;
372 	struct kcm_sock *kcm;
373 	struct sk_buff *head, *skb;
374 	size_t eaten = 0, cand_len;
375 	ssize_t extra;
376 	int err;
377 	bool cloned_orig = false;
378 
379 	if (psock->ready_rx_msg)
380 		return 0;
381 
382 	head = psock->rx_skb_head;
383 	if (head) {
384 		/* Message already in progress */
385 
386 		rxm = kcm_rx_msg(head);
387 		if (unlikely(rxm->early_eaten)) {
388 			/* Some number of bytes from the receive sock were
389 			 * already saved in rx_skb_head; just indicate that
390 			 * they are consumed.
391 			 */
392 			eaten = orig_len <= rxm->early_eaten ?
393 				orig_len : rxm->early_eaten;
394 			rxm->early_eaten -= eaten;
395 
396 			return eaten;
397 		}
398 
399 		if (unlikely(orig_offset)) {
400 			/* Getting data with a non-zero offset when a message is
401 			 * in progress is not expected. If it does happen, we
402 			 * need to clone and pull since we can't deal with
403 			 * offsets in the skbs for a message except in the head.
404 			 */
405 			orig_skb = skb_clone(orig_skb, GFP_ATOMIC);
406 			if (!orig_skb) {
407 				KCM_STATS_INCR(psock->stats.rx_mem_fail);
408 				desc->error = -ENOMEM;
409 				return 0;
410 			}
411 			if (!pskb_pull(orig_skb, orig_offset)) {
412 				KCM_STATS_INCR(psock->stats.rx_mem_fail);
413 				kfree_skb(orig_skb);
414 				desc->error = -ENOMEM;
415 				return 0;
416 			}
417 			cloned_orig = true;
418 			orig_offset = 0;
419 		}
420 
421 		if (!psock->rx_skb_nextp) {
422 			/* We are going to append to the frag_list of head.
423 			 * Need to unshare the frag_list.
424 			 */
425 			err = skb_unclone(head, GFP_ATOMIC);
426 			if (err) {
427 				KCM_STATS_INCR(psock->stats.rx_mem_fail);
428 				desc->error = err;
429 				return 0;
430 			}
431 
432 			if (unlikely(skb_shinfo(head)->frag_list)) {
433 				/* We can't append to an sk_buff that already
434 				 * has a frag_list. We create a new head, point
435 				 * the frag_list of that to the old head, and
436 				 * then are able to use the old head->next for
437 				 * appending to the message.
438 				 */
439 				if (WARN_ON(head->next)) {
440 					desc->error = -EINVAL;
441 					return 0;
442 				}
443 
444 				skb = alloc_skb(0, GFP_ATOMIC);
445 				if (!skb) {
446 					KCM_STATS_INCR(psock->stats.rx_mem_fail);
447 					desc->error = -ENOMEM;
448 					return 0;
449 				}
450 				skb->len = head->len;
451 				skb->data_len = head->len;
452 				skb->truesize = head->truesize;
453 				*kcm_rx_msg(skb) = *kcm_rx_msg(head);
454 				psock->rx_skb_nextp = &head->next;
455 				skb_shinfo(skb)->frag_list = head;
456 				psock->rx_skb_head = skb;
457 				head = skb;
458 			} else {
459 				psock->rx_skb_nextp =
460 				    &skb_shinfo(head)->frag_list;
461 			}
462 		}
463 	}
464 
465 	while (eaten < orig_len) {
466 		/* Always clone since we will consume something */
467 		skb = skb_clone(orig_skb, GFP_ATOMIC);
468 		if (!skb) {
469 			KCM_STATS_INCR(psock->stats.rx_mem_fail);
470 			desc->error = -ENOMEM;
471 			break;
472 		}
473 
474 		cand_len = orig_len - eaten;
475 
476 		head = psock->rx_skb_head;
477 		if (!head) {
478 			head = skb;
479 			psock->rx_skb_head = head;
480 			/* Will set rx_skb_nextp on next packet if needed */
481 			psock->rx_skb_nextp = NULL;
482 			rxm = kcm_rx_msg(head);
483 			memset(rxm, 0, sizeof(*rxm));
484 			rxm->offset = orig_offset + eaten;
485 		} else {
486 			/* Unclone since we may be appending to an skb that we
487 			 * already share a frag_list with.
488 			 */
489 			err = skb_unclone(skb, GFP_ATOMIC);
490 			if (err) {
491 				KCM_STATS_INCR(psock->stats.rx_mem_fail);
492 				desc->error = err;
493 				break;
494 			}
495 
496 			rxm = kcm_rx_msg(head);
497 			*psock->rx_skb_nextp = skb;
498 			psock->rx_skb_nextp = &skb->next;
499 			head->data_len += skb->len;
500 			head->len += skb->len;
501 			head->truesize += skb->truesize;
502 		}
503 
504 		if (!rxm->full_len) {
505 			ssize_t len;
506 
507 			len = KCM_RUN_FILTER(psock->bpf_prog, head);
508 
509 			if (!len) {
510 				/* Need more header to determine length */
511 				if (!rxm->accum_len) {
512 					/* Start RX timer for new message */
513 					kcm_start_rx_timer(psock);
514 				}
515 				rxm->accum_len += cand_len;
516 				eaten += cand_len;
517 				KCM_STATS_INCR(psock->stats.rx_need_more_hdr);
518 				WARN_ON(eaten != orig_len);
519 				break;
520 			} else if (len > psock->sk->sk_rcvbuf) {
521 				/* Message length exceeds maximum allowed */
522 				KCM_STATS_INCR(psock->stats.rx_msg_too_big);
523 				desc->error = -EMSGSIZE;
524 				psock->rx_skb_head = NULL;
525 				kcm_abort_rx_psock(psock, EMSGSIZE, head);
526 				break;
527 			} else if (len <= (ssize_t)head->len -
528 					  skb->len - rxm->offset) {
529 				/* Length must extend into the new skb (and
530 				 * also be greater than zero)
531 				 */
532 				KCM_STATS_INCR(psock->stats.rx_bad_hdr_len);
533 				desc->error = -EPROTO;
534 				psock->rx_skb_head = NULL;
535 				kcm_abort_rx_psock(psock, EPROTO, head);
536 				break;
537 			}
538 
539 			rxm->full_len = len;
540 		}
541 
542 		extra = (ssize_t)(rxm->accum_len + cand_len) - rxm->full_len;
543 
544 		if (extra < 0) {
545 			/* Message not complete yet. */
546 			if (rxm->full_len - rxm->accum_len >
547 			    tcp_inq(psock->sk)) {
548 				/* Don't have the whole message in the socket
549 				 * buffer. Set psock->rx_need_bytes to wait for
550 				 * the rest of the message. Also, set "early
551 				 * eaten" since we've already buffered the skb
552 				 * but don't consume yet per tcp_read_sock.
553 				 */
554 
555 				if (!rxm->accum_len) {
556 					/* Start RX timer for new message */
557 					kcm_start_rx_timer(psock);
558 				}
559 
560 				psock->rx_need_bytes = rxm->full_len -
561 						       rxm->accum_len;
562 				rxm->accum_len += cand_len;
563 				rxm->early_eaten = cand_len;
564 				KCM_STATS_ADD(psock->stats.rx_bytes, cand_len);
565 				desc->count = 0; /* Stop reading socket */
566 				break;
567 			}
568 			rxm->accum_len += cand_len;
569 			eaten += cand_len;
570 			WARN_ON(eaten != orig_len);
571 			break;
572 		}
573 
574 		/* Positive extra indicates more bytes than needed for the
575 		 * message
576 		 */
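		/* Worked example with illustrative numbers: for full_len = 100,
		 * a first chunk of 60 bytes leaves accum_len = 60 and
		 * extra = -40 (message incomplete). A following chunk of 70
		 * bytes gives extra = 60 + 70 - 100 = 30, so only
		 * cand_len - extra = 40 of those bytes are eaten for this
		 * message and the remaining 30 bytes start the next message.
		 */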
577 
578 		WARN_ON(extra > cand_len);
579 
580 		eaten += (cand_len - extra);
581 
582 		/* Hurray, we have a new message! */
583 		del_timer(&psock->rx_msg_timer);
584 		psock->rx_skb_head = NULL;
585 		KCM_STATS_INCR(psock->stats.rx_msgs);
586 
587 try_queue:
588 		kcm = reserve_rx_kcm(psock, head);
589 		if (!kcm) {
590 			/* Unable to reserve a KCM, message is held in psock. */
591 			break;
592 		}
593 
594 		if (kcm_queue_rcv_skb(&kcm->sk, head)) {
595 			/* Should mean socket buffer full */
596 			unreserve_rx_kcm(psock, false);
597 			goto try_queue;
598 		}
599 	}
600 
601 	if (cloned_orig)
602 		kfree_skb(orig_skb);
603 
604 	KCM_STATS_ADD(psock->stats.rx_bytes, eaten);
605 
606 	return eaten;
607 }
608 
609 /* Called with lock held on lower socket */
610 static int psock_tcp_read_sock(struct kcm_psock *psock)
611 {
612 	read_descriptor_t desc;
613 
614 	desc.arg.data = psock;
615 	desc.error = 0;
616 	desc.count = 1; /* give more than one skb per call */
617 
618 	/* sk should be locked here, so okay to do tcp_read_sock */
619 	tcp_read_sock(psock->sk, &desc, kcm_tcp_recv);
620 
621 	unreserve_rx_kcm(psock, true);
622 
623 	return desc.error;
624 }
625 
626 /* Lower sock lock held */
627 static void psock_tcp_data_ready(struct sock *sk)
628 {
629 	struct kcm_psock *psock;
630 
631 	read_lock_bh(&sk->sk_callback_lock);
632 
633 	psock = (struct kcm_psock *)sk->sk_user_data;
634 	if (unlikely(!psock || psock->rx_stopped))
635 		goto out;
636 
637 	if (psock->ready_rx_msg)
638 		goto out;
639 
640 	if (psock->rx_need_bytes) {
641 		if (tcp_inq(sk) >= psock->rx_need_bytes)
642 			psock->rx_need_bytes = 0;
643 		else
644 			goto out;
645 	}
646 
647 	if (psock_tcp_read_sock(psock) == -ENOMEM)
648 		queue_delayed_work(kcm_wq, &psock->rx_delayed_work, 0);
649 
650 out:
651 	read_unlock_bh(&sk->sk_callback_lock);
652 }
653 
654 static void do_psock_rx_work(struct kcm_psock *psock)
655 {
656 	read_descriptor_t rd_desc;
657 	struct sock *csk = psock->sk;
658 
659 	/* We need the read lock to synchronize with psock_tcp_data_ready. We
660 	 * need the socket lock for calling tcp_read_sock.
661 	 */
662 	lock_sock(csk);
663 	read_lock_bh(&csk->sk_callback_lock);
664 
665 	if (unlikely(csk->sk_user_data != psock))
666 		goto out;
667 
668 	if (unlikely(psock->rx_stopped))
669 		goto out;
670 
671 	if (psock->ready_rx_msg)
672 		goto out;
673 
674 	rd_desc.arg.data = psock;
675 
676 	if (psock_tcp_read_sock(psock) == -ENOMEM)
677 		queue_delayed_work(kcm_wq, &psock->rx_delayed_work, 0);
678 
679 out:
680 	read_unlock_bh(&csk->sk_callback_lock);
681 	release_sock(csk);
682 }
683 
684 static void psock_rx_work(struct work_struct *w)
685 {
686 	do_psock_rx_work(container_of(w, struct kcm_psock, rx_work));
687 }
688 
689 static void psock_rx_delayed_work(struct work_struct *w)
690 {
691 	do_psock_rx_work(container_of(w, struct kcm_psock,
692 				      rx_delayed_work.work));
693 }
694 
695 static void psock_tcp_state_change(struct sock *sk)
696 {
697 	/* TCP only does a POLLIN for a half close. Do a POLLHUP here
698 	 * since the application will normally not poll with POLLIN
699 	 * on the TCP sockets.
700 	 */
701 
702 	report_csk_error(sk, EPIPE);
703 }
704 
705 static void psock_tcp_write_space(struct sock *sk)
706 {
707 	struct kcm_psock *psock;
708 	struct kcm_mux *mux;
709 	struct kcm_sock *kcm;
710 
711 	read_lock_bh(&sk->sk_callback_lock);
712 
713 	psock = (struct kcm_psock *)sk->sk_user_data;
714 	if (unlikely(!psock))
715 		goto out;
716 
717 	mux = psock->mux;
718 
719 	spin_lock_bh(&mux->lock);
720 
721 	/* If the psock is reserved, a kcm socket is waiting to send on it. */
722 	kcm = psock->tx_kcm;
723 	if (kcm)
724 		queue_work(kcm_wq, &kcm->tx_work);
725 
726 	spin_unlock_bh(&mux->lock);
727 out:
728 	read_unlock_bh(&sk->sk_callback_lock);
729 }
730 
731 static void unreserve_psock(struct kcm_sock *kcm);
732 
733 /* kcm sock is locked. */
734 static struct kcm_psock *reserve_psock(struct kcm_sock *kcm)
735 {
736 	struct kcm_mux *mux = kcm->mux;
737 	struct kcm_psock *psock;
738 
739 	psock = kcm->tx_psock;
740 
741 	smp_rmb(); /* Must read tx_psock before tx_wait */
742 
743 	if (psock) {
744 		WARN_ON(kcm->tx_wait);
745 		if (unlikely(psock->tx_stopped))
746 			unreserve_psock(kcm);
747 		else
748 			return kcm->tx_psock;
749 	}
750 
751 	spin_lock_bh(&mux->lock);
752 
753 	/* Check again under lock to see if a psock was reserved for this
754 	 * kcm in the meantime (by psock_now_avail).
755 	 */
756 	psock = kcm->tx_psock;
757 	if (unlikely(psock)) {
758 		WARN_ON(kcm->tx_wait);
759 		spin_unlock_bh(&mux->lock);
760 		return kcm->tx_psock;
761 	}
762 
763 	if (!list_empty(&mux->psocks_avail)) {
764 		psock = list_first_entry(&mux->psocks_avail,
765 					 struct kcm_psock,
766 					 psock_avail_list);
767 		list_del(&psock->psock_avail_list);
768 		if (kcm->tx_wait) {
769 			list_del(&kcm->wait_psock_list);
770 			kcm->tx_wait = false;
771 		}
772 		kcm->tx_psock = psock;
773 		psock->tx_kcm = kcm;
774 		KCM_STATS_INCR(psock->stats.reserved);
775 	} else if (!kcm->tx_wait) {
776 		list_add_tail(&kcm->wait_psock_list,
777 			      &mux->kcm_tx_waiters);
778 		kcm->tx_wait = true;
779 	}
780 
781 	spin_unlock_bh(&mux->lock);
782 
783 	return psock;
784 }
785 
786 /* mux lock held */
787 static void psock_now_avail(struct kcm_psock *psock)
788 {
789 	struct kcm_mux *mux = psock->mux;
790 	struct kcm_sock *kcm;
791 
792 	if (list_empty(&mux->kcm_tx_waiters)) {
793 		list_add_tail(&psock->psock_avail_list,
794 			      &mux->psocks_avail);
795 	} else {
796 		kcm = list_first_entry(&mux->kcm_tx_waiters,
797 				       struct kcm_sock,
798 				       wait_psock_list);
799 		list_del(&kcm->wait_psock_list);
800 		kcm->tx_wait = false;
801 		psock->tx_kcm = kcm;
802 
803 		/* Commit before changing tx_psock since that is read in
804 		 * reserve_psock before queuing work.
805 		 */
806 		smp_mb();
807 
808 		kcm->tx_psock = psock;
809 		KCM_STATS_INCR(psock->stats.reserved);
810 		queue_work(kcm_wq, &kcm->tx_work);
811 	}
812 }
813 
814 /* kcm sock is locked. */
815 static void unreserve_psock(struct kcm_sock *kcm)
816 {
817 	struct kcm_psock *psock;
818 	struct kcm_mux *mux = kcm->mux;
819 
820 	spin_lock_bh(&mux->lock);
821 
822 	psock = kcm->tx_psock;
823 
824 	if (WARN_ON(!psock)) {
825 		spin_unlock_bh(&mux->lock);
826 		return;
827 	}
828 
829 	smp_rmb(); /* Read tx_psock before tx_wait */
830 
831 	kcm_update_tx_mux_stats(mux, psock);
832 
833 	WARN_ON(kcm->tx_wait);
834 
835 	kcm->tx_psock = NULL;
836 	psock->tx_kcm = NULL;
837 	KCM_STATS_INCR(psock->stats.unreserved);
838 
839 	if (unlikely(psock->tx_stopped)) {
840 		if (psock->done) {
841 			/* Deferred free */
842 			list_del(&psock->psock_list);
843 			mux->psocks_cnt--;
844 			sock_put(psock->sk);
845 			fput(psock->sk->sk_socket->file);
846 			kmem_cache_free(kcm_psockp, psock);
847 		}
848 
849 		/* Don't put back on available list */
850 
851 		spin_unlock_bh(&mux->lock);
852 
853 		return;
854 	}
855 
856 	psock_now_avail(psock);
857 
858 	spin_unlock_bh(&mux->lock);
859 }
860 
861 static void kcm_report_tx_retry(struct kcm_sock *kcm)
862 {
863 	struct kcm_mux *mux = kcm->mux;
864 
865 	spin_lock_bh(&mux->lock);
866 	KCM_STATS_INCR(mux->stats.tx_retries);
867 	spin_unlock_bh(&mux->lock);
868 }
869 
870 /* Write any messages ready on the kcm socket.  Called with kcm sock lock
871  * held.  Return bytes actually sent or error.
872  */
873 static int kcm_write_msgs(struct kcm_sock *kcm)
874 {
875 	struct sock *sk = &kcm->sk;
876 	struct kcm_psock *psock;
877 	struct sk_buff *skb, *head;
878 	struct kcm_tx_msg *txm;
879 	unsigned short fragidx, frag_offset;
880 	unsigned int sent, total_sent = 0;
881 	int ret = 0;
882 
883 	kcm->tx_wait_more = false;
884 	psock = kcm->tx_psock;
885 	if (unlikely(psock && psock->tx_stopped)) {
886 		/* A reserved psock was aborted asynchronously. Unreserve
887 		 * it and we'll retry the message.
888 		 */
889 		unreserve_psock(kcm);
890 		kcm_report_tx_retry(kcm);
891 		if (skb_queue_empty(&sk->sk_write_queue))
892 			return 0;
893 
894 		kcm_tx_msg(skb_peek(&sk->sk_write_queue))->sent = 0;
895 
896 	} else if (skb_queue_empty(&sk->sk_write_queue)) {
897 		return 0;
898 	}
899 
900 	head = skb_peek(&sk->sk_write_queue);
901 	txm = kcm_tx_msg(head);
902 
903 	if (txm->sent) {
904 		/* Send of first skbuff in queue already in progress */
905 		if (WARN_ON(!psock)) {
906 			ret = -EINVAL;
907 			goto out;
908 		}
909 		sent = txm->sent;
910 		frag_offset = txm->frag_offset;
911 		fragidx = txm->fragidx;
912 		skb = txm->frag_skb;
913 
914 		goto do_frag;
915 	}
916 
917 try_again:
918 	psock = reserve_psock(kcm);
919 	if (!psock)
920 		goto out;
921 
922 	do {
923 		skb = head;
924 		txm = kcm_tx_msg(head);
925 		sent = 0;
926 
927 do_frag_list:
928 		if (WARN_ON(!skb_shinfo(skb)->nr_frags)) {
929 			ret = -EINVAL;
930 			goto out;
931 		}
932 
933 		for (fragidx = 0; fragidx < skb_shinfo(skb)->nr_frags;
934 		     fragidx++) {
935 			skb_frag_t *frag;
936 
937 			frag_offset = 0;
938 do_frag:
939 			frag = &skb_shinfo(skb)->frags[fragidx];
940 			if (WARN_ON(!frag->size)) {
941 				ret = -EINVAL;
942 				goto out;
943 			}
944 
945 			ret = kernel_sendpage(psock->sk->sk_socket,
946 					      frag->page.p,
947 					      frag->page_offset + frag_offset,
948 					      frag->size - frag_offset,
949 					      MSG_DONTWAIT);
950 			if (ret <= 0) {
951 				if (ret == -EAGAIN) {
952 					/* Save state to try again when there's
953 					 * write space on the socket
954 					 */
955 					txm->sent = sent;
956 					txm->frag_offset = frag_offset;
957 					txm->fragidx = fragidx;
958 					txm->frag_skb = skb;
959 
960 					ret = 0;
961 					goto out;
962 				}
963 
964 				/* Hard failure in sending message, abort this
965 				 * psock since it has lost framing
966 				 * synchronization and retry sending the
967 				 * message from the beginning.
968 				 */
969 				kcm_abort_tx_psock(psock, ret ? -ret : EPIPE,
970 						   true);
971 				unreserve_psock(kcm);
972 
973 				txm->sent = 0;
974 				kcm_report_tx_retry(kcm);
975 				ret = 0;
976 
977 				goto try_again;
978 			}
979 
980 			sent += ret;
981 			frag_offset += ret;
982 			KCM_STATS_ADD(psock->stats.tx_bytes, ret);
983 			if (frag_offset < frag->size) {
984 				/* Not finished with this frag */
985 				goto do_frag;
986 			}
987 		}
988 
989 		if (skb == head) {
990 			if (skb_has_frag_list(skb)) {
991 				skb = skb_shinfo(skb)->frag_list;
992 				goto do_frag_list;
993 			}
994 		} else if (skb->next) {
995 			skb = skb->next;
996 			goto do_frag_list;
997 		}
998 
999 		/* Successfully sent the whole packet, account for it. */
1000 		skb_dequeue(&sk->sk_write_queue);
1001 		kfree_skb(head);
1002 		sk->sk_wmem_queued -= sent;
1003 		total_sent += sent;
1004 		KCM_STATS_INCR(psock->stats.tx_msgs);
1005 	} while ((head = skb_peek(&sk->sk_write_queue)));
1006 out:
1007 	if (!head) {
1008 		/* Done with all queued messages. */
1009 		WARN_ON(!skb_queue_empty(&sk->sk_write_queue));
1010 		unreserve_psock(kcm);
1011 	}
1012 
1013 	/* Check if write space is available */
1014 	sk->sk_write_space(sk);
1015 
1016 	return total_sent ? : ret;
1017 }
1018 
1019 static void kcm_tx_work(struct work_struct *w)
1020 {
1021 	struct kcm_sock *kcm = container_of(w, struct kcm_sock, tx_work);
1022 	struct sock *sk = &kcm->sk;
1023 	int err;
1024 
1025 	lock_sock(sk);
1026 
1027 	/* Primarily for SOCK_DGRAM sockets, also handle asynchronous tx
1028 	 * aborts
1029 	 */
1030 	err = kcm_write_msgs(kcm);
1031 	if (err < 0) {
1032 		/* Hard failure in write, report error on KCM socket */
1033 		pr_warn("KCM: Hard failure on kcm_write_msgs %d\n", err);
1034 		report_csk_error(&kcm->sk, -err);
1035 		goto out;
1036 	}
1037 
1038 	/* Primarily for SOCK_SEQPACKET sockets */
1039 	if (likely(sk->sk_socket) &&
1040 	    test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) {
1041 		clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1042 		sk->sk_write_space(sk);
1043 	}
1044 
1045 out:
1046 	release_sock(sk);
1047 }
1048 
1049 static void kcm_push(struct kcm_sock *kcm)
1050 {
1051 	if (kcm->tx_wait_more)
1052 		kcm_write_msgs(kcm);
1053 }
1054 
1055 static ssize_t kcm_sendpage(struct socket *sock, struct page *page,
1056 			    int offset, size_t size, int flags)
1057 
1058 {
1059 	struct sock *sk = sock->sk;
1060 	struct kcm_sock *kcm = kcm_sk(sk);
1061 	struct sk_buff *skb = NULL, *head = NULL;
1062 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1063 	bool eor;
1064 	int err = 0;
1065 	int i;
1066 
1067 	if (flags & MSG_SENDPAGE_NOTLAST)
1068 		flags |= MSG_MORE;
1069 
1070 	/* No MSG_EOR from splice, only look at MSG_MORE */
1071 	eor = !(flags & MSG_MORE);
1072 
1073 	lock_sock(sk);
1074 
1075 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1076 
1077 	err = -EPIPE;
1078 	if (sk->sk_err)
1079 		goto out_error;
1080 
1081 	if (kcm->seq_skb) {
1082 		/* Previously opened message */
1083 		head = kcm->seq_skb;
1084 		skb = kcm_tx_msg(head)->last_skb;
1085 		i = skb_shinfo(skb)->nr_frags;
1086 
1087 		if (skb_can_coalesce(skb, i, page, offset)) {
1088 			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], size);
1089 			skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
1090 			goto coalesced;
1091 		}
1092 
1093 		if (i >= MAX_SKB_FRAGS) {
1094 			struct sk_buff *tskb;
1095 
1096 			tskb = alloc_skb(0, sk->sk_allocation);
1097 			while (!tskb) {
1098 				kcm_push(kcm);
1099 				err = sk_stream_wait_memory(sk, &timeo);
1100 				if (err)
1101 					goto out_error;
1102 			}
1103 
1104 			if (head == skb)
1105 				skb_shinfo(head)->frag_list = tskb;
1106 			else
1107 				skb->next = tskb;
1108 
1109 			skb = tskb;
1110 			skb->ip_summed = CHECKSUM_UNNECESSARY;
1111 			i = 0;
1112 		}
1113 	} else {
1114 		/* Call the sk_stream functions to manage the sndbuf mem. */
1115 		if (!sk_stream_memory_free(sk)) {
1116 			kcm_push(kcm);
1117 			set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1118 			err = sk_stream_wait_memory(sk, &timeo);
1119 			if (err)
1120 				goto out_error;
1121 		}
1122 
1123 		head = alloc_skb(0, sk->sk_allocation);
1124 		while (!head) {
1125 			kcm_push(kcm);
1126 			err = sk_stream_wait_memory(sk, &timeo);
1127 			if (err)
1128 				goto out_error;
1129 		}
1130 
1131 		skb = head;
1132 		i = 0;
1133 	}
1134 
1135 	get_page(page);
1136 	skb_fill_page_desc(skb, i, page, offset, size);
1137 	skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
1138 
1139 coalesced:
1140 	skb->len += size;
1141 	skb->data_len += size;
1142 	skb->truesize += size;
1143 	sk->sk_wmem_queued += size;
1144 	sk_mem_charge(sk, size);
1145 
1146 	if (head != skb) {
1147 		head->len += size;
1148 		head->data_len += size;
1149 		head->truesize += size;
1150 	}
1151 
1152 	if (eor) {
1153 		bool not_busy = skb_queue_empty(&sk->sk_write_queue);
1154 
1155 		/* Message complete, queue it on send buffer */
1156 		__skb_queue_tail(&sk->sk_write_queue, head);
1157 		kcm->seq_skb = NULL;
1158 		KCM_STATS_INCR(kcm->stats.tx_msgs);
1159 
1160 		if (flags & MSG_BATCH) {
1161 			kcm->tx_wait_more = true;
1162 		} else if (kcm->tx_wait_more || not_busy) {
1163 			err = kcm_write_msgs(kcm);
1164 			if (err < 0) {
1165 				/* We got a hard error in write_msgs but have
1166 				 * already queued this message. Report an error
1167 				 * in the socket, but don't affect return value
1168 				 * from sendmsg
1169 				 */
1170 				pr_warn("KCM: Hard failure on kcm_write_msgs\n");
1171 				report_csk_error(&kcm->sk, -err);
1172 			}
1173 		}
1174 	} else {
1175 		/* Message not complete, save state */
1176 		kcm->seq_skb = head;
1177 		kcm_tx_msg(head)->last_skb = skb;
1178 	}
1179 
1180 	KCM_STATS_ADD(kcm->stats.tx_bytes, size);
1181 
1182 	release_sock(sk);
1183 	return size;
1184 
1185 out_error:
1186 	kcm_push(kcm);
1187 
1188 	err = sk_stream_error(sk, flags, err);
1189 
1190 	/* make sure we wake any epoll edge trigger waiter */
1191 	if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
1192 		sk->sk_write_space(sk);
1193 
1194 	release_sock(sk);
1195 	return err;
1196 }
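
/* Illustrative note (a sketch with placeholder descriptors): kcm_sendpage()
 * is typically reached via sendfile(2) or splice(2) with the KCM socket as
 * the output, e.g.
 *
 *	sendfile(kcm_fd, file_fd, NULL, msg_len);
 *
 * Batching across calls is driven by MSG_MORE / MSG_SENDPAGE_NOTLAST as
 * handled above.
 */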
1197 
1198 static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
1199 {
1200 	struct sock *sk = sock->sk;
1201 	struct kcm_sock *kcm = kcm_sk(sk);
1202 	struct sk_buff *skb = NULL, *head = NULL;
1203 	size_t copy, copied = 0;
1204 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1205 	int eor = (sock->type == SOCK_DGRAM) ?
1206 		  !(msg->msg_flags & MSG_MORE) : !!(msg->msg_flags & MSG_EOR);
1207 	int err = -EPIPE;
1208 
1209 	lock_sock(sk);
1210 
1211 	/* Per tcp_sendmsg this should be in poll */
1212 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1213 
1214 	if (sk->sk_err)
1215 		goto out_error;
1216 
1217 	if (kcm->seq_skb) {
1218 		/* Previously opened message */
1219 		head = kcm->seq_skb;
1220 		skb = kcm_tx_msg(head)->last_skb;
1221 		goto start;
1222 	}
1223 
1224 	/* Call the sk_stream functions to manage the sndbuf mem. */
1225 	if (!sk_stream_memory_free(sk)) {
1226 		kcm_push(kcm);
1227 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1228 		err = sk_stream_wait_memory(sk, &timeo);
1229 		if (err)
1230 			goto out_error;
1231 	}
1232 
1233 	/* New message, alloc head skb */
1234 	head = alloc_skb(0, sk->sk_allocation);
1235 	while (!head) {
1236 		kcm_push(kcm);
1237 		err = sk_stream_wait_memory(sk, &timeo);
1238 		if (err)
1239 			goto out_error;
1240 
1241 		head = alloc_skb(0, sk->sk_allocation);
1242 	}
1243 
1244 	skb = head;
1245 
1246 	/* Set ip_summed to CHECKSUM_UNNECESSARY to avoid calling
1247 	 * csum_and_copy_from_iter from skb_do_copy_data_nocache.
1248 	 */
1249 	skb->ip_summed = CHECKSUM_UNNECESSARY;
1250 
1251 start:
1252 	while (msg_data_left(msg)) {
1253 		bool merge = true;
1254 		int i = skb_shinfo(skb)->nr_frags;
1255 		struct page_frag *pfrag = sk_page_frag(sk);
1256 
1257 		if (!sk_page_frag_refill(sk, pfrag))
1258 			goto wait_for_memory;
1259 
1260 		if (!skb_can_coalesce(skb, i, pfrag->page,
1261 				      pfrag->offset)) {
1262 			if (i == MAX_SKB_FRAGS) {
1263 				struct sk_buff *tskb;
1264 
1265 				tskb = alloc_skb(0, sk->sk_allocation);
1266 				if (!tskb)
1267 					goto wait_for_memory;
1268 
1269 				if (head == skb)
1270 					skb_shinfo(head)->frag_list = tskb;
1271 				else
1272 					skb->next = tskb;
1273 
1274 				skb = tskb;
1275 				skb->ip_summed = CHECKSUM_UNNECESSARY;
1276 				continue;
1277 			}
1278 			merge = false;
1279 		}
1280 
1281 		copy = min_t(int, msg_data_left(msg),
1282 			     pfrag->size - pfrag->offset);
1283 
1284 		if (!sk_wmem_schedule(sk, copy))
1285 			goto wait_for_memory;
1286 
1287 		err = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb,
1288 					       pfrag->page,
1289 					       pfrag->offset,
1290 					       copy);
1291 		if (err)
1292 			goto out_error;
1293 
1294 		/* Update the skb. */
1295 		if (merge) {
1296 			skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
1297 		} else {
1298 			skb_fill_page_desc(skb, i, pfrag->page,
1299 					   pfrag->offset, copy);
1300 			get_page(pfrag->page);
1301 		}
1302 
1303 		pfrag->offset += copy;
1304 		copied += copy;
1305 		if (head != skb) {
1306 			head->len += copy;
1307 			head->data_len += copy;
1308 		}
1309 
1310 		continue;
1311 
1312 wait_for_memory:
1313 		kcm_push(kcm);
1314 		err = sk_stream_wait_memory(sk, &timeo);
1315 		if (err)
1316 			goto out_error;
1317 	}
1318 
1319 	if (eor) {
1320 		bool not_busy = skb_queue_empty(&sk->sk_write_queue);
1321 
1322 		/* Message complete, queue it on send buffer */
1323 		__skb_queue_tail(&sk->sk_write_queue, head);
1324 		kcm->seq_skb = NULL;
1325 		KCM_STATS_INCR(kcm->stats.tx_msgs);
1326 
1327 		if (msg->msg_flags & MSG_BATCH) {
1328 			kcm->tx_wait_more = true;
1329 		} else if (kcm->tx_wait_more || not_busy) {
1330 			err = kcm_write_msgs(kcm);
1331 			if (err < 0) {
1332 				/* We got a hard error in write_msgs but have
1333 				 * already queued this message. Report an error
1334 				 * in the socket, but don't affect return value
1335 				 * from sendmsg
1336 				 */
1337 				pr_warn("KCM: Hard failure on kcm_write_msgs\n");
1338 				report_csk_error(&kcm->sk, -err);
1339 			}
1340 		}
1341 	} else {
1342 		/* Message not complete, save state */
1343 partial_message:
1344 		kcm->seq_skb = head;
1345 		kcm_tx_msg(head)->last_skb = skb;
1346 	}
1347 
1348 	KCM_STATS_ADD(kcm->stats.tx_bytes, copied);
1349 
1350 	release_sock(sk);
1351 	return copied;
1352 
1353 out_error:
1354 	kcm_push(kcm);
1355 
1356 	if (copied && sock->type == SOCK_SEQPACKET) {
1357 		/* Wrote some bytes before encountering an
1358 		 * error, return partial success.
1359 		 */
1360 		goto partial_message;
1361 	}
1362 
1363 	if (head != kcm->seq_skb)
1364 		kfree_skb(head);
1365 
1366 	err = sk_stream_error(sk, msg->msg_flags, err);
1367 
1368 	/* make sure we wake any epoll edge trigger waiter */
1369 	if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
1370 		sk->sk_write_space(sk);
1371 
1372 	release_sock(sk);
1373 	return err;
1374 }
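
/* Illustrative sender-side usage (a sketch with placeholder names): several
 * messages can be batched with MSG_BATCH and flushed by a final send without
 * it, matching the tx_wait_more handling above:
 *
 *	send(kcm_fd, msg1, len1, MSG_BATCH);
 *	send(kcm_fd, msg2, len2, MSG_BATCH);
 *	send(kcm_fd, msg3, len3, 0);
 *
 * On a SOCK_SEQPACKET KCM socket a message may also be built up across
 * multiple sendmsg() calls and is terminated by setting MSG_EOR.
 */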
1375 
1376 static struct sk_buff *kcm_wait_data(struct sock *sk, int flags,
1377 				     long timeo, int *err)
1378 {
1379 	struct sk_buff *skb;
1380 
1381 	while (!(skb = skb_peek(&sk->sk_receive_queue))) {
1382 		if (sk->sk_err) {
1383 			*err = sock_error(sk);
1384 			return NULL;
1385 		}
1386 
1387 		if (sock_flag(sk, SOCK_DONE))
1388 			return NULL;
1389 
1390 		if ((flags & MSG_DONTWAIT) || !timeo) {
1391 			*err = -EAGAIN;
1392 			return NULL;
1393 		}
1394 
1395 		sk_wait_data(sk, &timeo, NULL);
1396 
1397 		/* Handle signals */
1398 		if (signal_pending(current)) {
1399 			*err = sock_intr_errno(timeo);
1400 			return NULL;
1401 		}
1402 	}
1403 
1404 	return skb;
1405 }
1406 
1407 static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
1408 		       size_t len, int flags)
1409 {
1410 	struct sock *sk = sock->sk;
1411 	struct kcm_sock *kcm = kcm_sk(sk);
1412 	int err = 0;
1413 	long timeo;
1414 	struct kcm_rx_msg *rxm;
1415 	int copied = 0;
1416 	struct sk_buff *skb;
1417 
1418 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1419 
1420 	lock_sock(sk);
1421 
1422 	skb = kcm_wait_data(sk, flags, timeo, &err);
1423 	if (!skb)
1424 		goto out;
1425 
1426 	/* Okay, have a message on the receive queue */
1427 
1428 	rxm = kcm_rx_msg(skb);
1429 
1430 	if (len > rxm->full_len)
1431 		len = rxm->full_len;
1432 
1433 	err = skb_copy_datagram_msg(skb, rxm->offset, msg, len);
1434 	if (err < 0)
1435 		goto out;
1436 
1437 	copied = len;
1438 	if (likely(!(flags & MSG_PEEK))) {
1439 		KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
1440 		if (copied < rxm->full_len) {
1441 			if (sock->type == SOCK_DGRAM) {
1442 				/* Truncated message */
1443 				msg->msg_flags |= MSG_TRUNC;
1444 				goto msg_finished;
1445 			}
1446 			rxm->offset += copied;
1447 			rxm->full_len -= copied;
1448 		} else {
1449 msg_finished:
1450 			/* Finished with message */
1451 			msg->msg_flags |= MSG_EOR;
1452 			KCM_STATS_INCR(kcm->stats.rx_msgs);
1453 			skb_unlink(skb, &sk->sk_receive_queue);
1454 			kfree_skb(skb);
1455 		}
1456 	}
1457 
1458 out:
1459 	release_sock(sk);
1460 
1461 	return copied ? : err;
1462 }
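
/* Illustrative receive-side usage (a sketch with placeholder names): each
 * recvmsg() returns data from at most one message. On SOCK_DGRAM a message
 * that does not fit is truncated and MSG_TRUNC is set; on SOCK_SEQPACKET the
 * remainder stays queued and can be read by further calls until MSG_EOR is
 * reported:
 *
 *	char buf[8192];
 *	ssize_t n = recv(kcm_fd, buf, sizeof(buf), 0);
 */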
1463 
1464 static ssize_t kcm_sock_splice(struct sock *sk,
1465 			       struct pipe_inode_info *pipe,
1466 			       struct splice_pipe_desc *spd)
1467 {
1468 	int ret;
1469 
1470 	release_sock(sk);
1471 	ret = splice_to_pipe(pipe, spd);
1472 	lock_sock(sk);
1473 
1474 	return ret;
1475 }
1476 
1477 static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
1478 			       struct pipe_inode_info *pipe, size_t len,
1479 			       unsigned int flags)
1480 {
1481 	struct sock *sk = sock->sk;
1482 	struct kcm_sock *kcm = kcm_sk(sk);
1483 	long timeo;
1484 	struct kcm_rx_msg *rxm;
1485 	int err = 0;
1486 	size_t copied;
1487 	struct sk_buff *skb;
1488 
1489 	/* Only support splice for SOCK_SEQPACKET */
1490 
1491 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1492 
1493 	lock_sock(sk);
1494 
1495 	skb = kcm_wait_data(sk, flags, timeo, &err);
1496 	if (!skb)
1497 		goto err_out;
1498 
1499 	/* Okay, have a message on the receive queue */
1500 
1501 	rxm = kcm_rx_msg(skb);
1502 
1503 	if (len > rxm->full_len)
1504 		len = rxm->full_len;
1505 
1506 	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, len, flags,
1507 				 kcm_sock_splice);
1508 	if (copied < 0) {
1509 		err = copied;
1510 		goto err_out;
1511 	}
1512 
1513 	KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
1514 
1515 	rxm->offset += copied;
1516 	rxm->full_len -= copied;
1517 
1518 	/* We have no way to return MSG_EOR. If all the bytes have been
1519 	 * read we still leave the message in the receive socket buffer.
1520 	 * A subsequent recvmsg needs to be done to return MSG_EOR and
1521 	 * finish reading the message.
1522 	 */
1523 
1524 	release_sock(sk);
1525 
1526 	return copied;
1527 
1528 err_out:
1529 	release_sock(sk);
1530 
1531 	return err;
1532 }
1533 
1534 /* kcm sock lock held */
1535 static void kcm_recv_disable(struct kcm_sock *kcm)
1536 {
1537 	struct kcm_mux *mux = kcm->mux;
1538 
1539 	if (kcm->rx_disabled)
1540 		return;
1541 
1542 	spin_lock_bh(&mux->rx_lock);
1543 
1544 	kcm->rx_disabled = 1;
1545 
1546 	/* If a psock is reserved we'll do cleanup in unreserve */
1547 	if (!kcm->rx_psock) {
1548 		if (kcm->rx_wait) {
1549 			list_del(&kcm->wait_rx_list);
1550 			kcm->rx_wait = false;
1551 		}
1552 
1553 		requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue);
1554 	}
1555 
1556 	spin_unlock_bh(&mux->rx_lock);
1557 }
1558 
1559 /* kcm sock lock held */
1560 static void kcm_recv_enable(struct kcm_sock *kcm)
1561 {
1562 	struct kcm_mux *mux = kcm->mux;
1563 
1564 	if (!kcm->rx_disabled)
1565 		return;
1566 
1567 	spin_lock_bh(&mux->rx_lock);
1568 
1569 	kcm->rx_disabled = 0;
1570 	kcm_rcv_ready(kcm);
1571 
1572 	spin_unlock_bh(&mux->rx_lock);
1573 }
1574 
1575 static int kcm_setsockopt(struct socket *sock, int level, int optname,
1576 			  char __user *optval, unsigned int optlen)
1577 {
1578 	struct kcm_sock *kcm = kcm_sk(sock->sk);
1579 	int val, valbool;
1580 	int err = 0;
1581 
1582 	if (level != SOL_KCM)
1583 		return -ENOPROTOOPT;
1584 
1585 	if (optlen < sizeof(int))
1586 		return -EINVAL;
1587 
1588 	if (get_user(val, (int __user *)optval))
1589 		return -EINVAL;
1590 
1591 	valbool = val ? 1 : 0;
1592 
1593 	switch (optname) {
1594 	case KCM_RECV_DISABLE:
1595 		lock_sock(&kcm->sk);
1596 		if (valbool)
1597 			kcm_recv_disable(kcm);
1598 		else
1599 			kcm_recv_enable(kcm);
1600 		release_sock(&kcm->sk);
1601 		break;
1602 	default:
1603 		err = -ENOPROTOOPT;
1604 	}
1605 
1606 	return err;
1607 }
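
/* Illustrative usage (a sketch with a placeholder descriptor): disabling
 * receive on one KCM socket so that queued and future messages are steered
 * to other KCM sockets on the same mux (see kcm_recv_disable() above):
 *
 *	int val = 1;
 *	setsockopt(kcm_fd, SOL_KCM, KCM_RECV_DISABLE, &val, sizeof(val));
 */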
1608 
1609 static int kcm_getsockopt(struct socket *sock, int level, int optname,
1610 			  char __user *optval, int __user *optlen)
1611 {
1612 	struct kcm_sock *kcm = kcm_sk(sock->sk);
1613 	int val, len;
1614 
1615 	if (level != SOL_KCM)
1616 		return -ENOPROTOOPT;
1617 
1618 	if (get_user(len, optlen))
1619 		return -EFAULT;
1620 
1621 	len = min_t(unsigned int, len, sizeof(int));
1622 	if (len < 0)
1623 		return -EINVAL;
1624 
1625 	switch (optname) {
1626 	case KCM_RECV_DISABLE:
1627 		val = kcm->rx_disabled;
1628 		break;
1629 	default:
1630 		return -ENOPROTOOPT;
1631 	}
1632 
1633 	if (put_user(len, optlen))
1634 		return -EFAULT;
1635 	if (copy_to_user(optval, &val, len))
1636 		return -EFAULT;
1637 	return 0;
1638 }
1639 
1640 static void init_kcm_sock(struct kcm_sock *kcm, struct kcm_mux *mux)
1641 {
1642 	struct kcm_sock *tkcm;
1643 	struct list_head *head;
1644 	int index = 0;
1645 
1646 	/* For SOCK_SEQPACKET sock type, datagram_poll checks the sk_state, so
1647 	 * we set sk_state, otherwise epoll_wait always returns right away with
1648 	 * POLLHUP
1649 	 */
1650 	kcm->sk.sk_state = TCP_ESTABLISHED;
1651 
1652 	/* Add to mux's kcm sockets list */
1653 	kcm->mux = mux;
1654 	spin_lock_bh(&mux->lock);
1655 
1656 	head = &mux->kcm_socks;
1657 	list_for_each_entry(tkcm, &mux->kcm_socks, kcm_sock_list) {
1658 		if (tkcm->index != index)
1659 			break;
1660 		head = &tkcm->kcm_sock_list;
1661 		index++;
1662 	}
1663 
1664 	list_add(&kcm->kcm_sock_list, head);
1665 	kcm->index = index;
1666 
1667 	mux->kcm_socks_cnt++;
1668 	spin_unlock_bh(&mux->lock);
1669 
1670 	INIT_WORK(&kcm->tx_work, kcm_tx_work);
1671 
1672 	spin_lock_bh(&mux->rx_lock);
1673 	kcm_rcv_ready(kcm);
1674 	spin_unlock_bh(&mux->rx_lock);
1675 }
1676 
1677 static void kcm_rx_msg_timeout(unsigned long arg)
1678 {
1679 	struct kcm_psock *psock = (struct kcm_psock *)arg;
1680 
1681 	/* Message assembly timed out */
1682 	KCM_STATS_INCR(psock->stats.rx_msg_timeouts);
1683 	kcm_abort_rx_psock(psock, ETIMEDOUT, NULL);
1684 }
1685 
1686 static int kcm_attach(struct socket *sock, struct socket *csock,
1687 		      struct bpf_prog *prog)
1688 {
1689 	struct kcm_sock *kcm = kcm_sk(sock->sk);
1690 	struct kcm_mux *mux = kcm->mux;
1691 	struct sock *csk;
1692 	struct kcm_psock *psock = NULL, *tpsock;
1693 	struct list_head *head;
1694 	int index = 0;
1695 
1696 	if (csock->ops->family != PF_INET &&
1697 	    csock->ops->family != PF_INET6)
1698 		return -EINVAL;
1699 
1700 	csk = csock->sk;
1701 	if (!csk)
1702 		return -EINVAL;
1703 
1704 	/* Only support TCP for now */
1705 	if (csk->sk_protocol != IPPROTO_TCP)
1706 		return -EINVAL;
1707 
1708 	psock = kmem_cache_zalloc(kcm_psockp, GFP_KERNEL);
1709 	if (!psock)
1710 		return -ENOMEM;
1711 
1712 	psock->mux = mux;
1713 	psock->sk = csk;
1714 	psock->bpf_prog = prog;
1715 
1716 	setup_timer(&psock->rx_msg_timer, kcm_rx_msg_timeout,
1717 		    (unsigned long)psock);
1718 
1719 	INIT_WORK(&psock->rx_work, psock_rx_work);
1720 	INIT_DELAYED_WORK(&psock->rx_delayed_work, psock_rx_delayed_work);
1721 
1722 	sock_hold(csk);
1723 
1724 	write_lock_bh(&csk->sk_callback_lock);
1725 	psock->save_data_ready = csk->sk_data_ready;
1726 	psock->save_write_space = csk->sk_write_space;
1727 	psock->save_state_change = csk->sk_state_change;
1728 	csk->sk_user_data = psock;
1729 	csk->sk_data_ready = psock_tcp_data_ready;
1730 	csk->sk_write_space = psock_tcp_write_space;
1731 	csk->sk_state_change = psock_tcp_state_change;
1732 	write_unlock_bh(&csk->sk_callback_lock);
1733 
1734 	/* Finished initialization, now add the psock to the MUX. */
1735 	spin_lock_bh(&mux->lock);
1736 	head = &mux->psocks;
1737 	list_for_each_entry(tpsock, &mux->psocks, psock_list) {
1738 		if (tpsock->index != index)
1739 			break;
1740 		head = &tpsock->psock_list;
1741 		index++;
1742 	}
1743 
1744 	list_add(&psock->psock_list, head);
1745 	psock->index = index;
1746 
1747 	KCM_STATS_INCR(mux->stats.psock_attach);
1748 	mux->psocks_cnt++;
1749 	psock_now_avail(psock);
1750 	spin_unlock_bh(&mux->lock);
1751 
1752 	/* Schedule RX work in case there are already bytes queued */
1753 	queue_work(kcm_wq, &psock->rx_work);
1754 
1755 	return 0;
1756 }
1757 
1758 static int kcm_attach_ioctl(struct socket *sock, struct kcm_attach *info)
1759 {
1760 	struct socket *csock;
1761 	struct bpf_prog *prog;
1762 	int err;
1763 
1764 	csock = sockfd_lookup(info->fd, &err);
1765 	if (!csock)
1766 		return -ENOENT;
1767 
1768 	prog = bpf_prog_get(info->bpf_fd);
1769 	if (IS_ERR(prog)) {
1770 		err = PTR_ERR(prog);
1771 		goto out;
1772 	}
1773 
1774 	if (prog->type != BPF_PROG_TYPE_SOCKET_FILTER) {
1775 		bpf_prog_put(prog);
1776 		err = -EINVAL;
1777 		goto out;
1778 	}
1779 
1780 	err = kcm_attach(sock, csock, prog);
1781 	if (err) {
1782 		bpf_prog_put(prog);
1783 		goto out;
1784 	}
1785 
1786 	/* Keep reference on file also */
1787 
1788 	return 0;
1789 out:
1790 	fput(csock->file);
1791 	return err;
1792 }
1793 
1794 static void kcm_unattach(struct kcm_psock *psock)
1795 {
1796 	struct sock *csk = psock->sk;
1797 	struct kcm_mux *mux = psock->mux;
1798 
1799 	/* Stop getting callbacks from TCP socket. After this there should
1800 	 * be no way to reserve a kcm for this psock.
1801 	 */
1802 	write_lock_bh(&csk->sk_callback_lock);
1803 	csk->sk_user_data = NULL;
1804 	csk->sk_data_ready = psock->save_data_ready;
1805 	csk->sk_write_space = psock->save_write_space;
1806 	csk->sk_state_change = psock->save_state_change;
1807 	psock->rx_stopped = 1;
1808 
1809 	if (WARN_ON(psock->rx_kcm)) {
1810 		write_unlock_bh(&csk->sk_callback_lock);
1811 		return;
1812 	}
1813 
1814 	spin_lock_bh(&mux->rx_lock);
1815 
1816 	/* Stop receiver activities. After this point psock should not be
1817 	 * able to get onto ready list either through callbacks or work.
1818 	 */
1819 	if (psock->ready_rx_msg) {
1820 		list_del(&psock->psock_ready_list);
1821 		kfree_skb(psock->ready_rx_msg);
1822 		psock->ready_rx_msg = NULL;
1823 		KCM_STATS_INCR(mux->stats.rx_ready_drops);
1824 	}
1825 
1826 	spin_unlock_bh(&mux->rx_lock);
1827 
1828 	write_unlock_bh(&csk->sk_callback_lock);
1829 
1830 	del_timer_sync(&psock->rx_msg_timer);
1831 	cancel_work_sync(&psock->rx_work);
1832 	cancel_delayed_work_sync(&psock->rx_delayed_work);
1833 
1834 	bpf_prog_put(psock->bpf_prog);
1835 
1836 	kfree_skb(psock->rx_skb_head);
1837 	psock->rx_skb_head = NULL;
1838 
1839 	spin_lock_bh(&mux->lock);
1840 
1841 	aggregate_psock_stats(&psock->stats, &mux->aggregate_psock_stats);
1842 
1843 	KCM_STATS_INCR(mux->stats.psock_unattach);
1844 
1845 	if (psock->tx_kcm) {
1846 		/* psock was reserved.  Just mark it finished and we will clean
1847 		 * up in the kcm paths; we need the kcm lock, which cannot be
1848 		 * acquired here.
1849 		 */
1850 		KCM_STATS_INCR(mux->stats.psock_unattach_rsvd);
1851 		spin_unlock_bh(&mux->lock);
1852 
1853 		/* We are unattaching a socket that is reserved. Abort the
1854 		 * socket since we may be out of sync in sending on it. We need
1855 		 * to do this without the mux lock.
1856 		 */
1857 		kcm_abort_tx_psock(psock, EPIPE, false);
1858 
1859 		spin_lock_bh(&mux->lock);
1860 		if (!psock->tx_kcm) {
1861 			/* psock was unreserved while the mux lock was dropped */
1862 			goto no_reserved;
1863 		}
1864 		psock->done = 1;
1865 
1866 		/* Commit done before queuing work to process it */
1867 		smp_mb();
1868 
1869 		/* Queue tx work to make sure psock->done is handled */
1870 		queue_work(kcm_wq, &psock->tx_kcm->tx_work);
1871 		spin_unlock_bh(&mux->lock);
1872 	} else {
1873 no_reserved:
1874 		if (!psock->tx_stopped)
1875 			list_del(&psock->psock_avail_list);
1876 		list_del(&psock->psock_list);
1877 		mux->psocks_cnt--;
1878 		spin_unlock_bh(&mux->lock);
1879 
1880 		sock_put(csk);
1881 		fput(csk->sk_socket->file);
1882 		kmem_cache_free(kcm_psockp, psock);
1883 	}
1884 }
1885 
1886 static int kcm_unattach_ioctl(struct socket *sock, struct kcm_unattach *info)
1887 {
1888 	struct kcm_sock *kcm = kcm_sk(sock->sk);
1889 	struct kcm_mux *mux = kcm->mux;
1890 	struct kcm_psock *psock;
1891 	struct socket *csock;
1892 	struct sock *csk;
1893 	int err;
1894 
1895 	csock = sockfd_lookup(info->fd, &err);
1896 	if (!csock)
1897 		return -ENOENT;
1898 
1899 	csk = csock->sk;
1900 	if (!csk) {
1901 		err = -EINVAL;
1902 		goto out;
1903 	}
1904 
1905 	err = -ENOENT;
1906 
1907 	spin_lock_bh(&mux->lock);
1908 
1909 	list_for_each_entry(psock, &mux->psocks, psock_list) {
1910 		if (psock->sk != csk)
1911 			continue;
1912 
1913 		/* Found the matching psock */
1914 
1915 		if (psock->unattaching || WARN_ON(psock->done)) {
1916 			err = -EALREADY;
1917 			break;
1918 		}
1919 
1920 		psock->unattaching = 1;
1921 
1922 		spin_unlock_bh(&mux->lock);
1923 
1924 		kcm_unattach(psock);
1925 
1926 		err = 0;
1927 		goto out;
1928 	}
1929 
1930 	spin_unlock_bh(&mux->lock);
1931 
1932 out:
1933 	fput(csock->file);
1934 	return err;
1935 }
1936 
1937 static struct proto kcm_proto = {
1938 	.name	= "KCM",
1939 	.owner	= THIS_MODULE,
1940 	.obj_size = sizeof(struct kcm_sock),
1941 };
1942 
1943 /* Clone a kcm socket. */
1944 static int kcm_clone(struct socket *osock, struct kcm_clone *info,
1945 		     struct socket **newsockp)
1946 {
1947 	struct socket *newsock;
1948 	struct sock *newsk;
1949 	struct file *newfile;
1950 	int err, newfd;
1951 
1952 	err = -ENFILE;
1953 	newsock = sock_alloc();
1954 	if (!newsock)
1955 		goto out;
1956 
1957 	newsock->type = osock->type;
1958 	newsock->ops = osock->ops;
1959 
1960 	__module_get(newsock->ops->owner);
1961 
1962 	newfd = get_unused_fd_flags(0);
1963 	if (unlikely(newfd < 0)) {
1964 		err = newfd;
1965 		goto out_fd_fail;
1966 	}
1967 
1968 	newfile = sock_alloc_file(newsock, 0, osock->sk->sk_prot_creator->name);
1969 	if (unlikely(IS_ERR(newfile))) {
1970 		err = PTR_ERR(newfile);
1971 		goto out_sock_alloc_fail;
1972 	}
1973 
1974 	newsk = sk_alloc(sock_net(osock->sk), PF_KCM, GFP_KERNEL,
1975 			 &kcm_proto, true);
1976 	if (!newsk) {
1977 		err = -ENOMEM;
1978 		goto out_sk_alloc_fail;
1979 	}
1980 
1981 	sock_init_data(newsock, newsk);
1982 	init_kcm_sock(kcm_sk(newsk), kcm_sk(osock->sk)->mux);
1983 
1984 	fd_install(newfd, newfile);
1985 	*newsockp = newsock;
1986 	info->fd = newfd;
1987 
1988 	return 0;
1989 
1990 out_sk_alloc_fail:
1991 	fput(newfile);
1992 out_sock_alloc_fail:
1993 	put_unused_fd(newfd);
1994 out_fd_fail:
1995 	sock_release(newsock);
1996 out:
1997 	return err;
1998 }
1999 
2000 static int kcm_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
2001 {
2002 	int err;
2003 
2004 	switch (cmd) {
2005 	case SIOCKCMATTACH: {
2006 		struct kcm_attach info;
2007 
2008 		if (copy_from_user(&info, (void __user *)arg, sizeof(info)))
2009 			err = -EFAULT;
2010 
2011 		err = kcm_attach_ioctl(sock, &info);
2012 
2013 		break;
2014 	}
2015 	case SIOCKCMUNATTACH: {
2016 		struct kcm_unattach info;
2017 
2018 		if (copy_from_user(&info, (void __user *)arg, sizeof(info)))
2019 			err = -EFAULT;
2020 
2021 		err = kcm_unattach_ioctl(sock, &info);
2022 
2023 		break;
2024 	}
2025 	case SIOCKCMCLONE: {
2026 		struct kcm_clone info;
2027 		struct socket *newsock = NULL;
2028 
2029 		if (copy_from_user(&info, (void __user *)arg, sizeof(info)))
2030 			err = -EFAULT;
2031 
2032 		err = kcm_clone(sock, &info, &newsock);
2033 
2034 		if (!err) {
2035 			if (copy_to_user((void __user *)arg, &info,
2036 					 sizeof(info))) {
2037 				err = -EFAULT;
2038 				sock_release(newsock);
2039 			}
2040 		}
2041 
2042 		break;
2043 	}
2044 	default:
2045 		err = -ENOIOCTLCMD;
2046 		break;
2047 	}
2048 
2049 	return err;
2050 }
2051 
2052 static void free_mux(struct rcu_head *rcu)
2053 {
2054 	struct kcm_mux *mux = container_of(rcu,
2055 	    struct kcm_mux, rcu);
2056 
2057 	kmem_cache_free(kcm_muxp, mux);
2058 }
2059 
2060 static void release_mux(struct kcm_mux *mux)
2061 {
2062 	struct kcm_net *knet = mux->knet;
2063 	struct kcm_psock *psock, *tmp_psock;
2064 
2065 	/* Release psocks */
2066 	list_for_each_entry_safe(psock, tmp_psock,
2067 				 &mux->psocks, psock_list) {
2068 		if (!WARN_ON(psock->unattaching))
2069 			kcm_unattach(psock);
2070 	}
2071 
2072 	if (WARN_ON(mux->psocks_cnt))
2073 		return;
2074 
2075 	__skb_queue_purge(&mux->rx_hold_queue);
2076 
2077 	mutex_lock(&knet->mutex);
2078 	aggregate_mux_stats(&mux->stats, &knet->aggregate_mux_stats);
2079 	aggregate_psock_stats(&mux->aggregate_psock_stats,
2080 			      &knet->aggregate_psock_stats);
2081 	list_del_rcu(&mux->kcm_mux_list);
2082 	knet->count--;
2083 	mutex_unlock(&knet->mutex);
2084 
2085 	call_rcu(&mux->rcu, free_mux);
2086 }
2087 
2088 static void kcm_done(struct kcm_sock *kcm)
2089 {
2090 	struct kcm_mux *mux = kcm->mux;
2091 	struct sock *sk = &kcm->sk;
2092 	int socks_cnt;
2093 
2094 	spin_lock_bh(&mux->rx_lock);
2095 	if (kcm->rx_psock) {
2096 		/* Cleanup in unreserve_rx_kcm */
2097 		WARN_ON(kcm->done);
2098 		kcm->rx_disabled = 1;
2099 		kcm->done = 1;
2100 		spin_unlock_bh(&mux->rx_lock);
2101 		return;
2102 	}
2103 
2104 	if (kcm->rx_wait) {
2105 		list_del(&kcm->wait_rx_list);
2106 		kcm->rx_wait = false;
2107 	}
2108 	/* Move any pending receive messages to other kcm sockets */
2109 	requeue_rx_msgs(mux, &sk->sk_receive_queue);
2110 
2111 	spin_unlock_bh(&mux->rx_lock);
2112 
2113 	if (WARN_ON(sk_rmem_alloc_get(sk)))
2114 		return;
2115 
2116 	/* Detach from MUX */
2117 	spin_lock_bh(&mux->lock);
2118 
2119 	list_del(&kcm->kcm_sock_list);
2120 	mux->kcm_socks_cnt--;
2121 	socks_cnt = mux->kcm_socks_cnt;
2122 
2123 	spin_unlock_bh(&mux->lock);
2124 
2125 	if (!socks_cnt) {
2126 		/* We are done with the mux now. */
2127 		release_mux(mux);
2128 	}
2129 
2130 	WARN_ON(kcm->rx_wait);
2131 
2132 	sock_put(&kcm->sk);
2133 }
2134 
2135 /* Called by kcm_release to close a KCM socket.
2136  * If this is the last KCM socket on the MUX, destroy the MUX.
2137  */
2138 static int kcm_release(struct socket *sock)
2139 {
2140 	struct sock *sk = sock->sk;
2141 	struct kcm_sock *kcm;
2142 	struct kcm_mux *mux;
2143 	struct kcm_psock *psock;
2144 
2145 	if (!sk)
2146 		return 0;
2147 
2148 	kcm = kcm_sk(sk);
2149 	mux = kcm->mux;
2150 
2151 	sock_orphan(sk);
2152 	kfree_skb(kcm->seq_skb);
2153 
2154 	lock_sock(sk);
2155 	/* Purge queue under lock to avoid race condition with tx_work trying
2156 	 * to act when queue is nonempty. If tx_work runs after this point
2157 	 * it will just return.
2158 	 */
2159 	__skb_queue_purge(&sk->sk_write_queue);
2160 	release_sock(sk);
2161 
2162 	spin_lock_bh(&mux->lock);
2163 	if (kcm->tx_wait) {
2164 		/* Take off the tx_wait list; after this point there should be no way
2165 		 * that a psock will be assigned to this kcm.
2166 		 */
2167 		list_del(&kcm->wait_psock_list);
2168 		kcm->tx_wait = false;
2169 	}
2170 	spin_unlock_bh(&mux->lock);
2171 
2172 	/* Cancel work. After this point there should be no outside references
2173 	 * to the kcm socket.
2174 	 */
2175 	cancel_work_sync(&kcm->tx_work);
2176 
2177 	lock_sock(sk);
2178 	psock = kcm->tx_psock;
2179 	if (psock) {
2180 		/* A psock was reserved, so we need to kill it since it
2181 		 * may already have some bytes queued from a message. We
2182 		 * need to do this after removing kcm from tx_wait list.
2183 		 */
2184 		kcm_abort_tx_psock(psock, EPIPE, false);
2185 		unreserve_psock(kcm);
2186 	}
2187 	release_sock(sk);
2188 
2189 	WARN_ON(kcm->tx_wait);
2190 	WARN_ON(kcm->tx_psock);
2191 
2192 	sock->sk = NULL;
2193 
2194 	kcm_done(kcm);
2195 
2196 	return 0;
2197 }
2198 
2199 static const struct proto_ops kcm_dgram_ops = {
2200 	.family =	PF_KCM,
2201 	.owner =	THIS_MODULE,
2202 	.release =	kcm_release,
2203 	.bind =		sock_no_bind,
2204 	.connect =	sock_no_connect,
2205 	.socketpair =	sock_no_socketpair,
2206 	.accept =	sock_no_accept,
2207 	.getname =	sock_no_getname,
2208 	.poll =		datagram_poll,
2209 	.ioctl =	kcm_ioctl,
2210 	.listen =	sock_no_listen,
2211 	.shutdown =	sock_no_shutdown,
2212 	.setsockopt =	kcm_setsockopt,
2213 	.getsockopt =	kcm_getsockopt,
2214 	.sendmsg =	kcm_sendmsg,
2215 	.recvmsg =	kcm_recvmsg,
2216 	.mmap =		sock_no_mmap,
2217 	.sendpage =	kcm_sendpage,
2218 };
2219 
2220 static const struct proto_ops kcm_seqpacket_ops = {
2221 	.family =	PF_KCM,
2222 	.owner =	THIS_MODULE,
2223 	.release =	kcm_release,
2224 	.bind =		sock_no_bind,
2225 	.connect =	sock_no_connect,
2226 	.socketpair =	sock_no_socketpair,
2227 	.accept =	sock_no_accept,
2228 	.getname =	sock_no_getname,
2229 	.poll =		datagram_poll,
2230 	.ioctl =	kcm_ioctl,
2231 	.listen =	sock_no_listen,
2232 	.shutdown =	sock_no_shutdown,
2233 	.setsockopt =	kcm_setsockopt,
2234 	.getsockopt =	kcm_getsockopt,
2235 	.sendmsg =	kcm_sendmsg,
2236 	.recvmsg =	kcm_recvmsg,
2237 	.mmap =		sock_no_mmap,
2238 	.sendpage =	kcm_sendpage,
2239 	.splice_read =	kcm_splice_read,
2240 };
2241 
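/* Note that only the SOCK_SEQPACKET ops above wire up ->splice_read
 * (kcm_splice_read), so splicing received messages into a pipe is a
 * SEQPACKET-only path in this file. A rough userspace sketch, assuming
 * "kcm_fd" is a SOCK_SEQPACKET KCM socket and "pipefd" came from pipe();
 * both names are placeholders:
 *
 *	// Splice up to 4096 bytes of the next parsed message into the pipe.
 *	ssize_t n = splice(kcm_fd, NULL, pipefd[1], NULL, 4096, 0);
 */
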
2242 /* Create a KCM socket and the mux it will be attached to */
2243 static int kcm_create(struct net *net, struct socket *sock,
2244 		      int protocol, int kern)
2245 {
2246 	struct kcm_net *knet = net_generic(net, kcm_net_id);
2247 	struct sock *sk;
2248 	struct kcm_mux *mux;
2249 
2250 	switch (sock->type) {
2251 	case SOCK_DGRAM:
2252 		sock->ops = &kcm_dgram_ops;
2253 		break;
2254 	case SOCK_SEQPACKET:
2255 		sock->ops = &kcm_seqpacket_ops;
2256 		break;
2257 	default:
2258 		return -ESOCKTNOSUPPORT;
2259 	}
2260 
2261 	if (protocol != KCMPROTO_CONNECTED)
2262 		return -EPROTONOSUPPORT;
2263 
2264 	sk = sk_alloc(net, PF_KCM, GFP_KERNEL, &kcm_proto, kern);
2265 	if (!sk)
2266 		return -ENOMEM;
2267 
2268 	/* Allocate a kcm mux, shared between KCM sockets */
2269 	mux = kmem_cache_zalloc(kcm_muxp, GFP_KERNEL);
2270 	if (!mux) {
2271 		sk_free(sk);
2272 		return -ENOMEM;
2273 	}
2274 
2275 	spin_lock_init(&mux->lock);
2276 	spin_lock_init(&mux->rx_lock);
2277 	INIT_LIST_HEAD(&mux->kcm_socks);
2278 	INIT_LIST_HEAD(&mux->kcm_rx_waiters);
2279 	INIT_LIST_HEAD(&mux->kcm_tx_waiters);
2280 
2281 	INIT_LIST_HEAD(&mux->psocks);
2282 	INIT_LIST_HEAD(&mux->psocks_ready);
2283 	INIT_LIST_HEAD(&mux->psocks_avail);
2284 
2285 	mux->knet = knet;
2286 
2287 	/* Add new MUX to list */
2288 	mutex_lock(&knet->mutex);
2289 	list_add_rcu(&mux->kcm_mux_list, &knet->mux_list);
2290 	knet->count++;
2291 	mutex_unlock(&knet->mutex);
2292 
2293 	skb_queue_head_init(&mux->rx_hold_queue);
2294 
2295 	/* Init KCM socket */
2296 	sock_init_data(sock, sk);
2297 	init_kcm_sock(kcm_sk(sk), mux);
2298 
2299 	return 0;
2300 }
2301 
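/* Illustrative sketch of what reaches kcm_create() from userspace; the
 * descriptor names are placeholders. Each successful socket() call below
 * allocates its own mux; further sockets join an existing mux only via
 * SIOCKCMCLONE:
 *
 *	int dgram_fd = socket(AF_KCM, SOCK_DGRAM, KCMPROTO_CONNECTED);
 *	int seqp_fd  = socket(AF_KCM, SOCK_SEQPACKET, KCMPROTO_CONNECTED);
 *
 *	// Anything else is rejected before a sock is even allocated:
 *	socket(AF_KCM, SOCK_STREAM, KCMPROTO_CONNECTED);  // -ESOCKTNOSUPPORT
 *	socket(AF_KCM, SOCK_DGRAM, 1);                    // -EPROTONOSUPPORT
 */
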
2302 static struct net_proto_family kcm_family_ops = {
2303 	.family = PF_KCM,
2304 	.create = kcm_create,
2305 	.owner  = THIS_MODULE,
2306 };
2307 
2308 static __net_init int kcm_init_net(struct net *net)
2309 {
2310 	struct kcm_net *knet = net_generic(net, kcm_net_id);
2311 
2312 	INIT_LIST_HEAD_RCU(&knet->mux_list);
2313 	mutex_init(&knet->mutex);
2314 
2315 	return 0;
2316 }
2317 
2318 static __net_exit void kcm_exit_net(struct net *net)
2319 {
2320 	struct kcm_net *knet = net_generic(net, kcm_net_id);
2321 
2322 	/* All KCM sockets should be closed at this point, which should mean
2323 	 * that all multiplexors and psocks have been destroyed.
2324 	 */
2325 	WARN_ON(!list_empty(&knet->mux_list));
2326 }
2327 
2328 static struct pernet_operations kcm_net_ops = {
2329 	.init = kcm_init_net,
2330 	.exit = kcm_exit_net,
2331 	.id   = &kcm_net_id,
2332 	.size = sizeof(struct kcm_net),
2333 };
2334 
2335 static int __init kcm_init(void)
2336 {
2337 	int err = -ENOMEM;
2338 
2339 	kcm_muxp = kmem_cache_create("kcm_mux_cache",
2340 				     sizeof(struct kcm_mux), 0,
2341 				     SLAB_HWCACHE_ALIGN | SLAB_PANIC, NULL);
2342 	if (!kcm_muxp)
2343 		goto fail;
2344 
2345 	kcm_psockp = kmem_cache_create("kcm_psock_cache",
2346 				       sizeof(struct kcm_psock), 0,
2347 				       SLAB_HWCACHE_ALIGN | SLAB_PANIC, NULL);
2348 	if (!kcm_psockp)
2349 		goto fail;
2350 
2351 	kcm_wq = create_singlethread_workqueue("kkcmd");
2352 	if (!kcm_wq)
2353 		goto fail;
2354 
2355 	err = proto_register(&kcm_proto, 1);
2356 	if (err)
2357 		goto fail;
2358 
2359 	err = sock_register(&kcm_family_ops);
2360 	if (err)
2361 		goto sock_register_fail;
2362 
2363 	err = register_pernet_device(&kcm_net_ops);
2364 	if (err)
2365 		goto net_ops_fail;
2366 
2367 	err = kcm_proc_init();
2368 	if (err)
2369 		goto proc_init_fail;
2370 
2371 	return 0;
2372 
2373 proc_init_fail:
2374 	unregister_pernet_device(&kcm_net_ops);
2375 
2376 net_ops_fail:
2377 	sock_unregister(PF_KCM);
2378 
2379 sock_register_fail:
2380 	proto_unregister(&kcm_proto);
2381 
2382 fail:
2383 	kmem_cache_destroy(kcm_muxp);
2384 	kmem_cache_destroy(kcm_psockp);
2385 
2386 	if (kcm_wq)
2387 		destroy_workqueue(kcm_wq);
2388 
2389 	return err;
2390 }
2391 
2392 static void __exit kcm_exit(void)
2393 {
2394 	kcm_proc_exit();
2395 	unregister_pernet_device(&kcm_net_ops);
2396 	sock_unregister(PF_KCM);
2397 	proto_unregister(&kcm_proto);
2398 	destroy_workqueue(kcm_wq);
2399 
2400 	kmem_cache_destroy(kcm_muxp);
2401 	kmem_cache_destroy(kcm_psockp);
2402 }
2403 
2404 module_init(kcm_init);
2405 module_exit(kcm_exit);
2406 
2407 MODULE_LICENSE("GPL");
2408 MODULE_ALIAS_NETPROTO(PF_KCM);
2409 