xref: /openbmc/linux/net/core/skmsg.c (revision 113094f7)
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>

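/* msg->sg.data is used as a ring of scatterlist elements: sg.start and
 * sg.end wrap at MAX_MSG_FRAGS, so sg.end < sg.start means the used region
 * wraps around the end of the array. Coalescing into the current tail
 * element is only allowed when elem_first_coalesce (roughly, the first
 * element the caller is allowed to grow) still lies inside the used region;
 * an empty ring never coalesces.
 */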
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}

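/* Grow the message to @len total bytes, pulling the extra space from the
 * socket's page_frag allocator and charging it to the socket's send-memory
 * accounting. New bytes coalesce into the tail element when
 * sk_msg_try_coalesce_ok() permits, otherwise a fresh ring element is
 * consumed. Returns 0, -ENOMEM if page refill or memory scheduling fails,
 * or -ENOSPC once the ring is full. A hypothetical caller growing a message
 * by 'copy' bytes (passing sg.end forbids growing any pre-existing element)
 * might do:
 *
 *	ret = sk_msg_alloc(sk, msg, msg->sg.size + copy, msg->sg.end);
 *	if (ret)
 *		goto wait_for_memory;
 */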
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);

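/* Share the byte range [@off, @off + @len) of @src into @dst without
 * copying data: each covered chunk either extends a virtually contiguous
 * tail element of @dst or is appended as a new element referencing the same
 * page, and @sk is charged for the shared bytes. Returns -ENOSPC if @src
 * ends before the range is covered or @dst runs out of ring slots.
 */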
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;
	u32 sge_len, sge_off;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		if (sge_len > len)
			sge_len = len;

		if (dst->sg.end)
			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		if (sgd &&
		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		} else {
			return -ENOSPC;
		}

		off = 0;
		len -= sge_len;
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

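/* Give back @bytes from the front of the message: fully consumed elements
 * are zeroed and sg.start is advanced past them, a partially consumed
 * element keeps its remainder, and the socket is uncharged as we go. Page
 * references are not dropped here.
 */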
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

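/* Like sk_msg_return_zero(), but only fixes up the memory accounting: walk
 * every element from sg.start to sg.end, uncharging at most @bytes in
 * total, and leave the ring contents untouched.
 */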
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

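/* Freeing helpers. sk_msg_free_elem() drops the (optional) memory charge
 * and, unless the data belongs to an skb (msg->skb set), the page
 * reference; __sk_msg_free() walks the whole ring this way and finally
 * consumes the skb, if any. The *_nocharge variants are for memory that
 * was never charged to the socket.
 */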
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	if (charge)
		sk_mem_uncharge(sk, len);
	if (!msg->skb)
		put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));
	return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	if (msg->skb)
		consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}

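/* Trim the message from the tail so that exactly @len bytes remain: whole
 * trailing elements are freed and uncharged, the last partial element is
 * shrunk in place, and curr/copybreak are pulled back when already-copied
 * data was trimmed. A hypothetical caller dropping 'drop' bytes:
 *
 *	sk_msg_trim(sk, msg, msg->sg.size - drop);
 */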
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
out:
	/* If we trim data before the curr pointer, update copybreak and curr
	 * so that any future copy operations start at the new copy location.
	 * However, trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (msg->sg.curr >= i) {
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
	sk_msg_iter_var_next(i);
	msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);

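/* Pin up to @bytes of user memory from @from directly into the ring
 * (zerocopy): iov_iter_get_pages() takes page references, each per-page
 * chunk becomes one ring element, and the socket is charged for every
 * byte. On error the iterator is reverted; per the comment below, the
 * caller must trim the message if the sg state also needs to be rolled
 * back.
 */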
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set. In this case clear it and prefer
		 * the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates; msg will need to use 'trim' later if the
	 * sg entries also need to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

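/* Copy up to @bytes from @from into space previously reserved with
 * sk_msg_alloc(), starting at the curr element and the copybreak offset
 * within it, and advancing both as data lands. Returns -ENOSPC if no
 * element had room, -EFAULT on a short copy from the iterator, otherwise
 * the size of the last chunk copied (msg->sg.curr marks where copying
 * stopped).
 */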
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			ret = -EFAULT;
			goto out;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

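/* Take ownership of a verdict-passed skb for local receive: wrap its data
 * in a freshly allocated sk_msg via skb_to_sgvec() (no copy), charge the
 * bytes to the socket's receive memory, queue the msg on the psock's
 * ingress list and wake the receiver. Returns the queued byte count, a
 * negative errno from skb_to_sgvec(), or -EAGAIN when allocation or rmem
 * scheduling fails.
 */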
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;
	struct sk_msg *msg;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return -EAGAIN;
	if (!sk_rmem_schedule(sk, skb, skb->len)) {
		kfree(msg);
		return -EAGAIN;
	}

	sk_msg_init(msg);
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	sk_mem_charge(sk, skb->len);
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.size = copied;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
	return copied;
}

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (ingress)
		return sk_psock_skb_ingress(psock, skb);
	else
		return skb_send_sock_locked(psock->sk, skb, off, len);
}

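/* Work-queue handler draining psock->ingress_skb: each skb is either
 * re-queued locally (ingress verdict) or sent with skb_send_sock_locked().
 * An -EAGAIN result parks the skb and its offsets in psock->work_state so
 * the next run resumes mid-skb; any other failure reports the error (EPIPE
 * by default) and clears SK_PSOCK_TX_ENABLED.
 */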
static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				kfree_skb(skb);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}

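/* Allocate a psock and attach it to @sk: the psock is published through
 * sk_user_data under RCU and holds a socket reference until torn down via
 * sk_psock_drop(). A hypothetical attach/detach sequence:
 *
 *	psock = sk_psock_init(sk, NUMA_NO_NODE);
 *	if (!psock)
 *		return -ENOMEM;
 *	...
 *	sk_psock_drop(sk, psock);
 */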
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);
	if (!psock)
		return NULL;

	psock->sk = sk;
	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data(sk, psock);
	sock_hold(sk);

	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}

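/* Final teardown, bounced from the RCU callback to a workqueue so that
 * blocking calls (strp_done(), cancel_work_sync()) run in process context:
 * stop the parser, drop program and link references, purge queued ingress
 * data and release the redirect and socket references.
 */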
static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */

	/* Parser has been stopped */
	if (psock->progs.skb_parser)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	rcu_assign_sk_user_data(sk, NULL);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);
	sk_psock_restore_proto(sk, psock);

	write_lock_bh(&sk->sk_callback_lock);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

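/* Map a BPF program's SK_PASS/SK_DROP return code onto the internal
 * actions: SK_PASS with a pending socket redirect becomes __SK_REDIRECT,
 * plain SK_PASS stays __SK_PASS, and anything else (including unknown
 * codes) is treated as __SK_DROP.
 */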
static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	preempt_disable();
	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = BPF_PROG_RUN(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	preempt_enable();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	int ret;

	skb->sk = psock->sk;
	bpf_compute_data_end_sk_skb(skb);
	preempt_disable();
	ret = BPF_PROG_RUN(prog, skb);
	preempt_enable();
	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to any socket yet.
	 */
	skb->sk = NULL;
	return ret;
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}

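/* Apply the skb verdict program's decision: __SK_PASS queues the skb on
 * our own ingress queue when the socket is alive and has receive space,
 * __SK_REDIRECT queues it on the target psock (owning egress skbs via
 * skb_set_owner_w()), and any other outcome, including a full target,
 * frees the skb.
 */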
static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;
	bool ingress;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
			goto out_free;
		}
		if (atomic_read(&sk_other->sk_rmem_alloc) <=
		    sk_other->sk_rcvbuf) {
			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

			tcp->bpf.flags |= BPF_F_INGRESS;
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);
			break;
		}
		goto out_free;
	case __SK_REDIRECT:
		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))
			goto out_free;
		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
			goto out_free;
		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		    (ingress &&
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {
			if (!ingress)
				skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
			break;
		}
		/* fall-through */
	case __SK_DROP:
		/* fall-through */
	default:
out_free:
		kfree_skb(skb);
	}
}

static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = __SK_DROP;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb_orphan(skb);
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	rcu_read_unlock();
	sk_psock_verdict_apply(psock, skb, ret);
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = skb->len;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_parser);
	if (likely(prog))
		ret = sk_psock_bpf_run(psock, prog, skb);
	rcu_read_unlock();
	return ret;
}

/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		write_lock_bh(&sk->sk_callback_lock);
		strp_data_ready(&psock->parser.strp);
		write_unlock_bh(&sk->sk_callback_lock);
	}
	rcu_read_unlock();
}

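/* Kick the backlog worker once the socket is writable again. The psock can
 * already be gone when this runs (sk_psock() may return NULL during
 * teardown), so the saved write_space callback is looked up and invoked
 * only while a psock is still attached.
 */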
static void sk_psock_write_space(struct sock *sk)
{
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk) = NULL;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			schedule_work(&psock->work);
		write_space = psock->saved_write_space;
	}
	rcu_read_unlock();
	if (write_space)
		write_space(sk);
}

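/* strparser glue: sk_psock_init_strp() registers the rcv_msg/parse_msg
 * callbacks, sk_psock_start_strp() swaps in the psock's sk_data_ready and
 * sk_write_space handlers, and sk_psock_stop_strp() restores sk_data_ready
 * and quiesces the parser. A hypothetical lifecycle:
 *
 *	if (sk_psock_init_strp(sk, psock))
 *		goto out_drop;
 *	sk_psock_start_strp(sk, psock);
 *	...
 *	sk_psock_stop_strp(sk, psock);
 */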
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}
835