xref: /openbmc/linux/net/tls/tls_strp.c (revision ffdd9bd7a278e37aa80de9ccc0b511d7387c2be7)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */
3 
4 #include <linux/skbuff.h>
5 #include <linux/workqueue.h>
6 #include <net/strparser.h>
7 #include <net/tcp.h>
8 #include <net/sock.h>
9 #include <net/tls.h>
10 
11 #include "tls.h"
12 
/* Deferred-parsing workqueue: used when the lower socket is owned by a
 * process context or when parsing failed with -ENOMEM.
 * Created by tls_strp_dev_init(), destroyed by tls_strp_dev_exit().
 */
static struct workqueue_struct *tls_strp_wq;
14 
15 static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
16 {
17 	if (strp->stopped)
18 		return;
19 
20 	strp->stopped = 1;
21 
22 	/* Report an error on the lower socket */
23 	strp->sk->sk_err = -err;
24 	sk_error_report(strp->sk);
25 }
26 
27 static void tls_strp_anchor_free(struct tls_strparser *strp)
28 {
29 	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
30 
31 	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
32 	shinfo->frag_list = NULL;
33 	consume_skb(strp->anchor);
34 	strp->anchor = NULL;
35 }
36 
37 /* Create a new skb with the contents of input copied to its page frags */
38 static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
39 {
40 	struct strp_msg *rxm;
41 	struct sk_buff *skb;
42 	int i, err, offset;
43 
44 	skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
45 				   &err, strp->sk->sk_allocation);
46 	if (!skb)
47 		return NULL;
48 
49 	offset = strp->stm.offset;
50 	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
51 		skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
52 
53 		WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
54 					   skb_frag_address(frag),
55 					   skb_frag_size(frag)));
56 		offset += skb_frag_size(frag);
57 	}
58 
59 	skb_copy_header(skb, strp->anchor);
60 	rxm = strp_msg(skb);
61 	rxm->offset = 0;
62 	return skb;
63 }
64 
65 /* Steal the input skb, input msg is invalid after calling this function */
66 struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
67 {
68 	struct tls_strparser *strp = &ctx->strp;
69 
70 #ifdef CONFIG_TLS_DEVICE
71 	DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
72 #else
73 	/* This function turns an input into an output,
74 	 * that can only happen if we have offload.
75 	 */
76 	WARN_ON(1);
77 #endif
78 
79 	if (strp->copy_mode) {
80 		struct sk_buff *skb;
81 
82 		/* Replace anchor with an empty skb, this is a little
83 		 * dangerous but __tls_cur_msg() warns on empty skbs
84 		 * so hopefully we'll catch abuses.
85 		 */
86 		skb = alloc_skb(0, strp->sk->sk_allocation);
87 		if (!skb)
88 			return NULL;
89 
90 		swap(strp->anchor, skb);
91 		return skb;
92 	}
93 
94 	return tls_strp_msg_make_copy(strp);
95 }
96 
97 /* Force the input skb to be in copy mode. The data ownership remains
98  * with the input skb itself (meaning unpause will wipe it) but it can
99  * be modified.
100  */
101 int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
102 {
103 	struct tls_strparser *strp = &ctx->strp;
104 	struct sk_buff *skb;
105 
106 	if (strp->copy_mode)
107 		return 0;
108 
109 	skb = tls_strp_msg_make_copy(strp);
110 	if (!skb)
111 		return -ENOMEM;
112 
113 	tls_strp_anchor_free(strp);
114 	strp->anchor = skb;
115 
116 	tcp_read_done(strp->sk, strp->stm.full_len);
117 	strp->copy_mode = 1;
118 
119 	return 0;
120 }
121 
122 /* Make a clone (in the skb sense) of the input msg to keep a reference
123  * to the underlying data. The reference-holding skbs get placed on
124  * @dst.
125  */
126 int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
127 {
128 	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
129 
130 	if (strp->copy_mode) {
131 		struct sk_buff *skb;
132 
133 		WARN_ON_ONCE(!shinfo->nr_frags);
134 
135 		/* We can't skb_clone() the anchor, it gets wiped by unpause */
136 		skb = alloc_skb(0, strp->sk->sk_allocation);
137 		if (!skb)
138 			return -ENOMEM;
139 
140 		__skb_queue_tail(dst, strp->anchor);
141 		strp->anchor = skb;
142 	} else {
143 		struct sk_buff *iter, *clone;
144 		int chunk, len, offset;
145 
146 		offset = strp->stm.offset;
147 		len = strp->stm.full_len;
148 		iter = shinfo->frag_list;
149 
150 		while (len > 0) {
151 			if (iter->len <= offset) {
152 				offset -= iter->len;
153 				goto next;
154 			}
155 
156 			chunk = iter->len - offset;
157 			offset = 0;
158 
159 			clone = skb_clone(iter, strp->sk->sk_allocation);
160 			if (!clone)
161 				return -ENOMEM;
162 			__skb_queue_tail(dst, clone);
163 
164 			len -= chunk;
165 next:
166 			iter = iter->next;
167 		}
168 	}
169 
170 	return 0;
171 }
172 
173 static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
174 {
175 	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
176 	int i;
177 
178 	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
179 
180 	for (i = 0; i < shinfo->nr_frags; i++)
181 		__skb_frag_unref(&shinfo->frags[i], false);
182 	shinfo->nr_frags = 0;
183 	strp->copy_mode = 0;
184 }
185 
186 static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
187 			   unsigned int offset, size_t in_len)
188 {
189 	struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
190 	struct sk_buff *skb;
191 	skb_frag_t *frag;
192 	size_t len, chunk;
193 	int sz;
194 
195 	if (strp->msg_ready)
196 		return 0;
197 
198 	skb = strp->anchor;
199 	frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
200 
201 	len = in_len;
202 	/* First make sure we got the header */
203 	if (!strp->stm.full_len) {
204 		/* Assume one page is more than enough for headers */
205 		chunk =	min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
206 		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
207 					   skb_frag_address(frag) +
208 					   skb_frag_size(frag),
209 					   chunk));
210 
211 		sz = tls_rx_msg_size(strp, strp->anchor);
212 		if (sz < 0) {
213 			desc->error = sz;
214 			return 0;
215 		}
216 
217 		/* We may have over-read, sz == 0 is guaranteed under-read */
218 		if (sz > 0)
219 			chunk =	min_t(size_t, chunk, sz - skb->len);
220 
221 		skb->len += chunk;
222 		skb->data_len += chunk;
223 		skb_frag_size_add(frag, chunk);
224 		frag++;
225 		len -= chunk;
226 		offset += chunk;
227 
228 		strp->stm.full_len = sz;
229 		if (!strp->stm.full_len)
230 			goto read_done;
231 	}
232 
233 	/* Load up more data */
234 	while (len && strp->stm.full_len > skb->len) {
235 		chunk =	min_t(size_t, len, strp->stm.full_len - skb->len);
236 		chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
237 		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
238 					   skb_frag_address(frag) +
239 					   skb_frag_size(frag),
240 					   chunk));
241 
242 		skb->len += chunk;
243 		skb->data_len += chunk;
244 		skb_frag_size_add(frag, chunk);
245 		frag++;
246 		len -= chunk;
247 		offset += chunk;
248 	}
249 
250 	if (strp->stm.full_len == skb->len) {
251 		desc->count = 0;
252 
253 		strp->msg_ready = 1;
254 		tls_rx_msg_ready(strp);
255 	}
256 
257 read_done:
258 	return in_len - len;
259 }
260 
261 static int tls_strp_read_copyin(struct tls_strparser *strp)
262 {
263 	struct socket *sock = strp->sk->sk_socket;
264 	read_descriptor_t desc;
265 
266 	desc.arg.data = strp;
267 	desc.error = 0;
268 	desc.count = 1; /* give more than one skb per call */
269 
270 	/* sk should be locked here, so okay to do read_sock */
271 	sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);
272 
273 	return desc.error;
274 }
275 
276 static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
277 {
278 	struct skb_shared_info *shinfo;
279 	struct page *page;
280 	int need_spc, len;
281 
282 	/* If the rbuf is small or rcv window has collapsed to 0 we need
283 	 * to read the data out. Otherwise the connection will stall.
284 	 * Without pressure threshold of INT_MAX will never be ready.
285 	 */
286 	if (likely(qshort && !tcp_epollin_ready(strp->sk, INT_MAX)))
287 		return 0;
288 
289 	shinfo = skb_shinfo(strp->anchor);
290 	shinfo->frag_list = NULL;
291 
292 	/* If we don't know the length go max plus page for cipher overhead */
293 	need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
294 
295 	for (len = need_spc; len > 0; len -= PAGE_SIZE) {
296 		page = alloc_page(strp->sk->sk_allocation);
297 		if (!page) {
298 			tls_strp_flush_anchor_copy(strp);
299 			return -ENOMEM;
300 		}
301 
302 		skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
303 				   page, 0, 0);
304 	}
305 
306 	strp->copy_mode = 1;
307 	strp->stm.offset = 0;
308 
309 	strp->anchor->len = 0;
310 	strp->anchor->data_len = 0;
311 	strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);
312 
313 	tls_strp_read_copyin(strp);
314 
315 	return 0;
316 }
317 
318 static bool tls_strp_check_no_dup(struct tls_strparser *strp)
319 {
320 	unsigned int len = strp->stm.offset + strp->stm.full_len;
321 	struct sk_buff *skb;
322 	u32 seq;
323 
324 	skb = skb_shinfo(strp->anchor)->frag_list;
325 	seq = TCP_SKB_CB(skb)->seq;
326 
327 	while (skb->len < len) {
328 		seq += skb->len;
329 		len -= skb->len;
330 		skb = skb->next;
331 
332 		if (TCP_SKB_CB(skb)->seq != seq)
333 			return false;
334 	}
335 
336 	return true;
337 }
338 
339 static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
340 {
341 	struct tcp_sock *tp = tcp_sk(strp->sk);
342 	struct sk_buff *first;
343 	u32 offset;
344 
345 	first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
346 	if (WARN_ON_ONCE(!first))
347 		return;
348 
349 	/* Bestow the state onto the anchor */
350 	strp->anchor->len = offset + len;
351 	strp->anchor->data_len = offset + len;
352 	strp->anchor->truesize = offset + len;
353 
354 	skb_shinfo(strp->anchor)->frag_list = first;
355 
356 	skb_copy_header(strp->anchor, first);
357 	strp->anchor->destructor = NULL;
358 
359 	strp->stm.offset = offset;
360 }
361 
362 void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
363 {
364 	struct strp_msg *rxm;
365 	struct tls_msg *tlm;
366 
367 	DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
368 	DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
369 
370 	if (!strp->copy_mode && force_refresh) {
371 		if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
372 			return;
373 
374 		tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
375 	}
376 
377 	rxm = strp_msg(strp->anchor);
378 	rxm->full_len	= strp->stm.full_len;
379 	rxm->offset	= strp->stm.offset;
380 	tlm = tls_msg(strp->anchor);
381 	tlm->control	= strp->mark;
382 }
383 
384 /* Called with lock held on lower socket */
385 static int tls_strp_read_sock(struct tls_strparser *strp)
386 {
387 	int sz, inq;
388 
389 	inq = tcp_inq(strp->sk);
390 	if (inq < 1)
391 		return 0;
392 
393 	if (unlikely(strp->copy_mode))
394 		return tls_strp_read_copyin(strp);
395 
396 	if (inq < strp->stm.full_len)
397 		return tls_strp_read_copy(strp, true);
398 
399 	if (!strp->stm.full_len) {
400 		tls_strp_load_anchor_with_queue(strp, inq);
401 
402 		sz = tls_rx_msg_size(strp, strp->anchor);
403 		if (sz < 0) {
404 			tls_strp_abort_strp(strp, sz);
405 			return sz;
406 		}
407 
408 		strp->stm.full_len = sz;
409 
410 		if (!strp->stm.full_len || inq < strp->stm.full_len)
411 			return tls_strp_read_copy(strp, true);
412 	}
413 
414 	if (!tls_strp_check_no_dup(strp))
415 		return tls_strp_read_copy(strp, false);
416 
417 	strp->msg_ready = 1;
418 	tls_rx_msg_ready(strp);
419 
420 	return 0;
421 }
422 
423 void tls_strp_check_rcv(struct tls_strparser *strp)
424 {
425 	if (unlikely(strp->stopped) || strp->msg_ready)
426 		return;
427 
428 	if (tls_strp_read_sock(strp) == -ENOMEM)
429 		queue_work(tls_strp_wq, &strp->work);
430 }
431 
432 /* Lower sock lock held */
433 void tls_strp_data_ready(struct tls_strparser *strp)
434 {
435 	/* This check is needed to synchronize with do_tls_strp_work.
436 	 * do_tls_strp_work acquires a process lock (lock_sock) whereas
437 	 * the lock held here is bh_lock_sock. The two locks can be
438 	 * held by different threads at the same time, but bh_lock_sock
439 	 * allows a thread in BH context to safely check if the process
440 	 * lock is held. In this case, if the lock is held, queue work.
441 	 */
442 	if (sock_owned_by_user_nocheck(strp->sk)) {
443 		queue_work(tls_strp_wq, &strp->work);
444 		return;
445 	}
446 
447 	tls_strp_check_rcv(strp);
448 }
449 
450 static void tls_strp_work(struct work_struct *w)
451 {
452 	struct tls_strparser *strp =
453 		container_of(w, struct tls_strparser, work);
454 
455 	lock_sock(strp->sk);
456 	tls_strp_check_rcv(strp);
457 	release_sock(strp->sk);
458 }
459 
460 void tls_strp_msg_done(struct tls_strparser *strp)
461 {
462 	WARN_ON(!strp->stm.full_len);
463 
464 	if (likely(!strp->copy_mode))
465 		tcp_read_done(strp->sk, strp->stm.full_len);
466 	else
467 		tls_strp_flush_anchor_copy(strp);
468 
469 	strp->msg_ready = 0;
470 	memset(&strp->stm, 0, sizeof(strp->stm));
471 
472 	tls_strp_check_rcv(strp);
473 }
474 
475 void tls_strp_stop(struct tls_strparser *strp)
476 {
477 	strp->stopped = 1;
478 }
479 
480 int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
481 {
482 	memset(strp, 0, sizeof(*strp));
483 
484 	strp->sk = sk;
485 
486 	strp->anchor = alloc_skb(0, GFP_KERNEL);
487 	if (!strp->anchor)
488 		return -ENOMEM;
489 
490 	INIT_WORK(&strp->work, tls_strp_work);
491 
492 	return 0;
493 }
494 
495 /* strp must already be stopped so that tls_strp_recv will no longer be called.
496  * Note that tls_strp_done is not called with the lower socket held.
497  */
498 void tls_strp_done(struct tls_strparser *strp)
499 {
500 	WARN_ON(!strp->stopped);
501 
502 	cancel_work_sync(&strp->work);
503 	tls_strp_anchor_free(strp);
504 }
505 
506 int __init tls_strp_dev_init(void)
507 {
508 	tls_strp_wq = create_workqueue("tls-strp");
509 	if (unlikely(!tls_strp_wq))
510 		return -ENOMEM;
511 
512 	return 0;
513 }
514 
515 void tls_strp_dev_exit(void)
516 {
517 	destroy_workqueue(tls_strp_wq);
518 }
519