xref: /openbmc/linux/net/tls/tls_sw.c (revision f1288bdb)
1  /*
2   * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3   * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4   * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5   * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6   * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7   * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
8   *
9   * This software is available to you under a choice of one of two
10   * licenses.  You may choose to be licensed under the terms of the GNU
11   * General Public License (GPL) Version 2, available from the file
12   * COPYING in the main directory of this source tree, or the
13   * OpenIB.org BSD license below:
14   *
15   *     Redistribution and use in source and binary forms, with or
16   *     without modification, are permitted provided that the following
17   *     conditions are met:
18   *
19   *      - Redistributions of source code must retain the above
20   *        copyright notice, this list of conditions and the following
21   *        disclaimer.
22   *
23   *      - Redistributions in binary form must reproduce the above
24   *        copyright notice, this list of conditions and the following
25   *        disclaimer in the documentation and/or other materials
26   *        provided with the distribution.
27   *
28   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
29   * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
30   * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
31   * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
32   * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
33   * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
34   * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35   * SOFTWARE.
36   */
37  
38  #include <linux/bug.h>
39  #include <linux/sched/signal.h>
40  #include <linux/module.h>
41  #include <linux/splice.h>
42  #include <crypto/aead.h>
43  
44  #include <net/strparser.h>
45  #include <net/tls.h>
46  
47  #include "tls.h"
48  
49  struct tls_decrypt_arg {
50  	struct_group(inargs,
51  	bool zc;
52  	bool async;
53  	u8 tail;
54  	);
55  
56  	struct sk_buff *skb;
57  };
58  
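/* Per-record decrypt scratch state, allocated together with the aead_request
 * in tls_decrypt_sg(): @iv holds the per-record nonce, @aad the additional
 * authenticated data, @tail receives the last plaintext byte in the zero-copy
 * case (the TLS 1.3 content type unless the record carries padding), and @sg
 * holds the input scatterlist entries followed by the output ones.
 */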
59  struct tls_decrypt_ctx {
60  	u8 iv[MAX_IV_SIZE];
61  	u8 aad[TLS_MAX_AAD_SIZE];
62  	u8 tail;
63  	struct scatterlist sg[];
64  };
65  
66  noinline void tls_err_abort(struct sock *sk, int err)
67  {
68  	WARN_ON_ONCE(err >= 0);
69  	/* sk->sk_err should contain a positive error code. */
70  	sk->sk_err = -err;
71  	sk_error_report(sk);
72  }
73  
74  static int __skb_nsg(struct sk_buff *skb, int offset, int len,
75                       unsigned int recursion_level)
76  {
77          int start = skb_headlen(skb);
78          int i, chunk = start - offset;
79          struct sk_buff *frag_iter;
80          int elt = 0;
81  
82          if (unlikely(recursion_level >= 24))
83                  return -EMSGSIZE;
84  
85          if (chunk > 0) {
86                  if (chunk > len)
87                          chunk = len;
88                  elt++;
89                  len -= chunk;
90                  if (len == 0)
91                          return elt;
92                  offset += chunk;
93          }
94  
95          for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
96                  int end;
97  
98                  WARN_ON(start > offset + len);
99  
100                  end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
101                  chunk = end - offset;
102                  if (chunk > 0) {
103                          if (chunk > len)
104                                  chunk = len;
105                          elt++;
106                          len -= chunk;
107                          if (len == 0)
108                                  return elt;
109                          offset += chunk;
110                  }
111                  start = end;
112          }
113  
114          if (unlikely(skb_has_frag_list(skb))) {
115                  skb_walk_frags(skb, frag_iter) {
116                          int end, ret;
117  
118                          WARN_ON(start > offset + len);
119  
120                          end = start + frag_iter->len;
121                          chunk = end - offset;
122                          if (chunk > 0) {
123                                  if (chunk > len)
124                                          chunk = len;
125                                  ret = __skb_nsg(frag_iter, offset - start, chunk,
126                                                  recursion_level + 1);
127                                  if (unlikely(ret < 0))
128                                          return ret;
129                                  elt += ret;
130                                  len -= chunk;
131                                  if (len == 0)
132                                          return elt;
133                                  offset += chunk;
134                          }
135                          start = end;
136                  }
137          }
138          BUG_ON(len);
139          return elt;
140  }
141  
142  /* Return the number of scatterlist elements required to completely map the
143   * skb, or -EMSGSIZE if the recursion depth is exceeded.
144   */
145  static int skb_nsg(struct sk_buff *skb, int offset, int len)
146  {
147          return __skb_nsg(skb, offset, len, 0);
148  }
149  
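/* TLS 1.3 hides the real record type inside the encrypted payload
 * (RFC 8446, 5.2), roughly:
 *
 *	TLSInnerPlaintext = content || content_type (1 byte) || zero padding
 *
 * tls_padding_length() scans backwards from the byte just before the
 * authentication tag for the last non-zero octet, records it as the real
 * content type and returns the number of trailing zero-padding bytes to trim.
 */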
150  static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
151  			      struct tls_decrypt_arg *darg)
152  {
153  	struct strp_msg *rxm = strp_msg(skb);
154  	struct tls_msg *tlm = tls_msg(skb);
155  	int sub = 0;
156  
157  	/* Determine zero-padding length */
158  	if (prot->version == TLS_1_3_VERSION) {
159  		int offset = rxm->full_len - TLS_TAG_SIZE - 1;
160  		char content_type = darg->zc ? darg->tail : 0;
161  		int err;
162  
163  		while (content_type == 0) {
164  			if (offset < prot->prepend_size)
165  				return -EBADMSG;
166  			err = skb_copy_bits(skb, rxm->offset + offset,
167  					    &content_type, 1);
168  			if (err)
169  				return err;
170  			if (content_type)
171  				break;
172  			sub++;
173  			offset--;
174  		}
175  		tlm->control = content_type;
176  	}
177  	return sub;
178  }
179  
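/* Completion callback for asynchronous decryption: record any error on the
 * socket, release the out-of-place destination pages (skipping the first S/G
 * entry, which points to the AAD) and wake the waiter once decrypt_pending
 * drops to zero.
 */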
180  static void tls_decrypt_done(struct crypto_async_request *req, int err)
181  {
182  	struct aead_request *aead_req = (struct aead_request *)req;
183  	struct scatterlist *sgout = aead_req->dst;
184  	struct scatterlist *sgin = aead_req->src;
185  	struct tls_sw_context_rx *ctx;
186  	struct tls_context *tls_ctx;
187  	struct scatterlist *sg;
188  	unsigned int pages;
189  	struct sock *sk;
190  
191  	sk = (struct sock *)req->data;
192  	tls_ctx = tls_get_ctx(sk);
193  	ctx = tls_sw_ctx_rx(tls_ctx);
194  
195  	/* Propagate if there was an err */
196  	if (err) {
197  		if (err == -EBADMSG)
198  			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
199  		ctx->async_wait.err = err;
200  		tls_err_abort(sk, err);
201  	}
202  
203  	/* Free the destination pages if the skb was not decrypted in place */
204  	if (sgout != sgin) {
205  		/* Skip the first S/G entry as it points to AAD */
206  		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
207  			if (!sg)
208  				break;
209  			put_page(sg_page(sg));
210  		}
211  	}
212  
213  	kfree(aead_req);
214  
215  	spin_lock_bh(&ctx->decrypt_compl_lock);
216  	if (!atomic_dec_return(&ctx->decrypt_pending))
217  		complete(&ctx->async_wait.completion);
218  	spin_unlock_bh(&ctx->decrypt_compl_lock);
219  }
220  
221  static int tls_do_decryption(struct sock *sk,
222  			     struct scatterlist *sgin,
223  			     struct scatterlist *sgout,
224  			     char *iv_recv,
225  			     size_t data_len,
226  			     struct aead_request *aead_req,
227  			     struct tls_decrypt_arg *darg)
228  {
229  	struct tls_context *tls_ctx = tls_get_ctx(sk);
230  	struct tls_prot_info *prot = &tls_ctx->prot_info;
231  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
232  	int ret;
233  
234  	aead_request_set_tfm(aead_req, ctx->aead_recv);
235  	aead_request_set_ad(aead_req, prot->aad_size);
236  	aead_request_set_crypt(aead_req, sgin, sgout,
237  			       data_len + prot->tag_size,
238  			       (u8 *)iv_recv);
239  
240  	if (darg->async) {
241  		aead_request_set_callback(aead_req,
242  					  CRYPTO_TFM_REQ_MAY_BACKLOG,
243  					  tls_decrypt_done, sk);
244  		atomic_inc(&ctx->decrypt_pending);
245  	} else {
246  		aead_request_set_callback(aead_req,
247  					  CRYPTO_TFM_REQ_MAY_BACKLOG,
248  					  crypto_req_done, &ctx->async_wait);
249  	}
250  
251  	ret = crypto_aead_decrypt(aead_req);
252  	if (ret == -EINPROGRESS) {
253  		if (darg->async)
254  			return 0;
255  
256  		ret = crypto_wait_req(ret, &ctx->async_wait);
257  	}
258  	darg->async = false;
259  
260  	return ret;
261  }
262  
263  static void tls_trim_both_msgs(struct sock *sk, int target_size)
264  {
265  	struct tls_context *tls_ctx = tls_get_ctx(sk);
266  	struct tls_prot_info *prot = &tls_ctx->prot_info;
267  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
268  	struct tls_rec *rec = ctx->open_rec;
269  
270  	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
271  	if (target_size > 0)
272  		target_size += prot->overhead_size;
273  	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
274  }
275  
276  static int tls_alloc_encrypted_msg(struct sock *sk, int len)
277  {
278  	struct tls_context *tls_ctx = tls_get_ctx(sk);
279  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
280  	struct tls_rec *rec = ctx->open_rec;
281  	struct sk_msg *msg_en = &rec->msg_encrypted;
282  
283  	return sk_msg_alloc(sk, msg_en, len, 0);
284  }
285  
286  static int tls_clone_plaintext_msg(struct sock *sk, int required)
287  {
288  	struct tls_context *tls_ctx = tls_get_ctx(sk);
289  	struct tls_prot_info *prot = &tls_ctx->prot_info;
290  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
291  	struct tls_rec *rec = ctx->open_rec;
292  	struct sk_msg *msg_pl = &rec->msg_plaintext;
293  	struct sk_msg *msg_en = &rec->msg_encrypted;
294  	int skip, len;
295  
296  	/* We add page references worth len bytes from the encrypted sg
297  	 * at the end of the plaintext sg. The caller guarantees that
298  	 * msg_en has enough room.
299  	 */
300  	len = required - msg_pl->sg.size;
301  
302  	/* Skip initial bytes in msg_en's data so that the same offset
303  	 * can be used for both plain and encrypted data.
304  	 */
305  	skip = prot->prepend_size + msg_pl->sg.size;
306  
307  	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
308  }
309  
310  static struct tls_rec *tls_get_rec(struct sock *sk)
311  {
312  	struct tls_context *tls_ctx = tls_get_ctx(sk);
313  	struct tls_prot_info *prot = &tls_ctx->prot_info;
314  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
315  	struct sk_msg *msg_pl, *msg_en;
316  	struct tls_rec *rec;
317  	int mem_size;
318  
319  	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
320  
321  	rec = kzalloc(mem_size, sk->sk_allocation);
322  	if (!rec)
323  		return NULL;
324  
325  	msg_pl = &rec->msg_plaintext;
326  	msg_en = &rec->msg_encrypted;
327  
328  	sk_msg_init(msg_pl);
329  	sk_msg_init(msg_en);
330  
331  	sg_init_table(rec->sg_aead_in, 2);
332  	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
333  	sg_unmark_end(&rec->sg_aead_in[1]);
334  
335  	sg_init_table(rec->sg_aead_out, 2);
336  	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
337  	sg_unmark_end(&rec->sg_aead_out[1]);
338  
339  	return rec;
340  }
341  
342  static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
343  {
344  	sk_msg_free(sk, &rec->msg_encrypted);
345  	sk_msg_free(sk, &rec->msg_plaintext);
346  	kfree(rec);
347  }
348  
349  static void tls_free_open_rec(struct sock *sk)
350  {
351  	struct tls_context *tls_ctx = tls_get_ctx(sk);
352  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
353  	struct tls_rec *rec = ctx->open_rec;
354  
355  	if (rec) {
356  		tls_free_rec(sk, rec);
357  		ctx->open_rec = NULL;
358  	}
359  }
360  
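/* Transmit encrypted records to the TCP layer: finish any partially sent
 * record first, then walk tx_list in order and push every record whose
 * encryption has completed (tx_ready), stopping at the first record that is
 * still pending so records always go out in sequence.
 */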
361  int tls_tx_records(struct sock *sk, int flags)
362  {
363  	struct tls_context *tls_ctx = tls_get_ctx(sk);
364  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
365  	struct tls_rec *rec, *tmp;
366  	struct sk_msg *msg_en;
367  	int tx_flags, rc = 0;
368  
369  	if (tls_is_partially_sent_record(tls_ctx)) {
370  		rec = list_first_entry(&ctx->tx_list,
371  				       struct tls_rec, list);
372  
373  		if (flags == -1)
374  			tx_flags = rec->tx_flags;
375  		else
376  			tx_flags = flags;
377  
378  		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
379  		if (rc)
380  			goto tx_err;
381  
382  		/* Full record has been transmitted.
383  		 * Remove the head of tx_list
384  		 */
385  		list_del(&rec->list);
386  		sk_msg_free(sk, &rec->msg_plaintext);
387  		kfree(rec);
388  	}
389  
390  	/* Tx all ready records */
391  	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
392  		if (READ_ONCE(rec->tx_ready)) {
393  			if (flags == -1)
394  				tx_flags = rec->tx_flags;
395  			else
396  				tx_flags = flags;
397  
398  			msg_en = &rec->msg_encrypted;
399  			rc = tls_push_sg(sk, tls_ctx,
400  					 &msg_en->sg.data[msg_en->sg.curr],
401  					 0, tx_flags);
402  			if (rc)
403  				goto tx_err;
404  
405  			list_del(&rec->list);
406  			sk_msg_free(sk, &rec->msg_plaintext);
407  			kfree(rec);
408  		} else {
409  			break;
410  		}
411  	}
412  
413  tx_err:
414  	if (rc < 0 && rc != -EAGAIN)
415  		tls_err_abort(sk, -EBADMSG);
416  
417  	return rc;
418  }
419  
420  static void tls_encrypt_done(struct crypto_async_request *req, int err)
421  {
422  	struct aead_request *aead_req = (struct aead_request *)req;
423  	struct sock *sk = req->data;
424  	struct tls_context *tls_ctx = tls_get_ctx(sk);
425  	struct tls_prot_info *prot = &tls_ctx->prot_info;
426  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
427  	struct scatterlist *sge;
428  	struct sk_msg *msg_en;
429  	struct tls_rec *rec;
430  	bool ready = false;
431  	int pending;
432  
433  	rec = container_of(aead_req, struct tls_rec, aead_req);
434  	msg_en = &rec->msg_encrypted;
435  
436  	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
437  	sge->offset -= prot->prepend_size;
438  	sge->length += prot->prepend_size;
439  
440  	/* Check if an error was previously set on the socket */
441  	if (err || sk->sk_err) {
442  		rec = NULL;
443  
444  		/* If err is already set on socket, return the same code */
445  		if (sk->sk_err) {
446  			ctx->async_wait.err = -sk->sk_err;
447  		} else {
448  			ctx->async_wait.err = err;
449  			tls_err_abort(sk, err);
450  		}
451  	}
452  
453  	if (rec) {
454  		struct tls_rec *first_rec;
455  
456  		/* Mark the record as ready for transmission */
457  		smp_store_mb(rec->tx_ready, true);
458  
459  		/* If received record is at head of tx_list, schedule tx */
460  		first_rec = list_first_entry(&ctx->tx_list,
461  					     struct tls_rec, list);
462  		if (rec == first_rec)
463  			ready = true;
464  	}
465  
466  	spin_lock_bh(&ctx->encrypt_compl_lock);
467  	pending = atomic_dec_return(&ctx->encrypt_pending);
468  
469  	if (!pending && ctx->async_notify)
470  		complete(&ctx->async_wait.completion);
471  	spin_unlock_bh(&ctx->encrypt_compl_lock);
472  
473  	if (!ready)
474  		return;
475  
476  	/* Schedule the transmission */
477  	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
478  		schedule_delayed_work(&ctx->tx_work.work, 1);
479  }
480  
481  static int tls_do_encryption(struct sock *sk,
482  			     struct tls_context *tls_ctx,
483  			     struct tls_sw_context_tx *ctx,
484  			     struct aead_request *aead_req,
485  			     size_t data_len, u32 start)
486  {
487  	struct tls_prot_info *prot = &tls_ctx->prot_info;
488  	struct tls_rec *rec = ctx->open_rec;
489  	struct sk_msg *msg_en = &rec->msg_encrypted;
490  	struct scatterlist *sge = sk_msg_elem(msg_en, start);
491  	int rc, iv_offset = 0;
492  
493  	/* For CCM based ciphers, first byte of IV is a constant */
494  	switch (prot->cipher_type) {
495  	case TLS_CIPHER_AES_CCM_128:
496  		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
497  		iv_offset = 1;
498  		break;
499  	case TLS_CIPHER_SM4_CCM:
500  		rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
501  		iv_offset = 1;
502  		break;
503  	}
504  
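	/* Build the per-record nonce: iv_data gets salt || IV from tx.iv and,
	 * for TLS 1.3 and ChaCha20-Poly1305, tls_xor_iv_with_seq() mixes the
	 * record sequence number into the IV so every record uses a unique
	 * nonce. The CCM B0 byte set above is left untouched.
	 */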
505  	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
506  	       prot->iv_size + prot->salt_size);
507  
508  	tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
509  			    tls_ctx->tx.rec_seq);
510  
511  	sge->offset += prot->prepend_size;
512  	sge->length -= prot->prepend_size;
513  
514  	msg_en->sg.curr = start;
515  
516  	aead_request_set_tfm(aead_req, ctx->aead_send);
517  	aead_request_set_ad(aead_req, prot->aad_size);
518  	aead_request_set_crypt(aead_req, rec->sg_aead_in,
519  			       rec->sg_aead_out,
520  			       data_len, rec->iv_data);
521  
522  	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
523  				  tls_encrypt_done, sk);
524  
525  	/* Add the record to tx_list */
526  	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
527  	atomic_inc(&ctx->encrypt_pending);
528  
529  	rc = crypto_aead_encrypt(aead_req);
530  	if (!rc || rc != -EINPROGRESS) {
531  		atomic_dec(&ctx->encrypt_pending);
532  		sge->offset -= prot->prepend_size;
533  		sge->length += prot->prepend_size;
534  	}
535  
536  	if (!rc) {
537  		WRITE_ONCE(rec->tx_ready, true);
538  	} else if (rc != -EINPROGRESS) {
539  		list_del(&rec->list);
540  		return rc;
541  	}
542  
543  	/* Unhook the record from the context if encryption did not fail */
544  	ctx->open_rec = NULL;
545  	tls_advance_record_sn(sk, prot, &tls_ctx->tx);
546  	return rc;
547  }
548  
549  static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
550  				 struct tls_rec **to, struct sk_msg *msg_opl,
551  				 struct sk_msg *msg_oen, u32 split_point,
552  				 u32 tx_overhead_size, u32 *orig_end)
553  {
554  	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
555  	struct scatterlist *sge, *osge, *nsge;
556  	u32 orig_size = msg_opl->sg.size;
557  	struct scatterlist tmp = { };
558  	struct sk_msg *msg_npl;
559  	struct tls_rec *new;
560  	int ret;
561  
562  	new = tls_get_rec(sk);
563  	if (!new)
564  		return -ENOMEM;
565  	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
566  			   tx_overhead_size, 0);
567  	if (ret < 0) {
568  		tls_free_rec(sk, new);
569  		return ret;
570  	}
571  
572  	*orig_end = msg_opl->sg.end;
573  	i = msg_opl->sg.start;
574  	sge = sk_msg_elem(msg_opl, i);
575  	while (apply && sge->length) {
576  		if (sge->length > apply) {
577  			u32 len = sge->length - apply;
578  
579  			get_page(sg_page(sge));
580  			sg_set_page(&tmp, sg_page(sge), len,
581  				    sge->offset + apply);
582  			sge->length = apply;
583  			bytes += apply;
584  			apply = 0;
585  		} else {
586  			apply -= sge->length;
587  			bytes += sge->length;
588  		}
589  
590  		sk_msg_iter_var_next(i);
591  		if (i == msg_opl->sg.end)
592  			break;
593  		sge = sk_msg_elem(msg_opl, i);
594  	}
595  
596  	msg_opl->sg.end = i;
597  	msg_opl->sg.curr = i;
598  	msg_opl->sg.copybreak = 0;
599  	msg_opl->apply_bytes = 0;
600  	msg_opl->sg.size = bytes;
601  
602  	msg_npl = &new->msg_plaintext;
603  	msg_npl->apply_bytes = apply;
604  	msg_npl->sg.size = orig_size - bytes;
605  
606  	j = msg_npl->sg.start;
607  	nsge = sk_msg_elem(msg_npl, j);
608  	if (tmp.length) {
609  		memcpy(nsge, &tmp, sizeof(*nsge));
610  		sk_msg_iter_var_next(j);
611  		nsge = sk_msg_elem(msg_npl, j);
612  	}
613  
614  	osge = sk_msg_elem(msg_opl, i);
615  	while (osge->length) {
616  		memcpy(nsge, osge, sizeof(*nsge));
617  		sg_unmark_end(nsge);
618  		sk_msg_iter_var_next(i);
619  		sk_msg_iter_var_next(j);
620  		if (i == *orig_end)
621  			break;
622  		osge = sk_msg_elem(msg_opl, i);
623  		nsge = sk_msg_elem(msg_npl, j);
624  	}
625  
626  	msg_npl->sg.end = j;
627  	msg_npl->sg.curr = j;
628  	msg_npl->sg.copybreak = 0;
629  
630  	*to = new;
631  	return 0;
632  }
633  
634  static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
635  				  struct tls_rec *from, u32 orig_end)
636  {
637  	struct sk_msg *msg_npl = &from->msg_plaintext;
638  	struct sk_msg *msg_opl = &to->msg_plaintext;
639  	struct scatterlist *osge, *nsge;
640  	u32 i, j;
641  
642  	i = msg_opl->sg.end;
643  	sk_msg_iter_var_prev(i);
644  	j = msg_npl->sg.start;
645  
646  	osge = sk_msg_elem(msg_opl, i);
647  	nsge = sk_msg_elem(msg_npl, j);
648  
649  	if (sg_page(osge) == sg_page(nsge) &&
650  	    osge->offset + osge->length == nsge->offset) {
651  		osge->length += nsge->length;
652  		put_page(sg_page(nsge));
653  	}
654  
655  	msg_opl->sg.end = orig_end;
656  	msg_opl->sg.curr = orig_end;
657  	msg_opl->sg.copybreak = 0;
658  	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
659  	msg_opl->sg.size += msg_npl->sg.size;
660  
661  	sk_msg_free(sk, &to->msg_encrypted);
662  	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
663  
664  	kfree(from);
665  }
666  
667  static int tls_push_record(struct sock *sk, int flags,
668  			   unsigned char record_type)
669  {
670  	struct tls_context *tls_ctx = tls_get_ctx(sk);
671  	struct tls_prot_info *prot = &tls_ctx->prot_info;
672  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
673  	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
674  	u32 i, split_point, orig_end;
675  	struct sk_msg *msg_pl, *msg_en;
676  	struct aead_request *req;
677  	bool split;
678  	int rc;
679  
680  	if (!rec)
681  		return 0;
682  
683  	msg_pl = &rec->msg_plaintext;
684  	msg_en = &rec->msg_encrypted;
685  
686  	split_point = msg_pl->apply_bytes;
687  	split = split_point && split_point < msg_pl->sg.size;
688  	if (unlikely((!split &&
689  		      msg_pl->sg.size +
690  		      prot->overhead_size > msg_en->sg.size) ||
691  		     (split &&
692  		      split_point +
693  		      prot->overhead_size > msg_en->sg.size))) {
694  		split = true;
695  		split_point = msg_en->sg.size;
696  	}
697  	if (split) {
698  		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
699  					   split_point, prot->overhead_size,
700  					   &orig_end);
701  		if (rc < 0)
702  			return rc;
703  		/* This can happen if tls_split_open_record() above allocates
704  		 * a single large encryption buffer instead of two smaller
705  		 * ones. In this case adjust pointers and continue without
706  		 * split.
707  		 */
708  		if (!msg_pl->sg.size) {
709  			tls_merge_open_record(sk, rec, tmp, orig_end);
710  			msg_pl = &rec->msg_plaintext;
711  			msg_en = &rec->msg_encrypted;
712  			split = false;
713  		}
714  		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
715  			    prot->overhead_size);
716  	}
717  
718  	rec->tx_flags = flags;
719  	req = &rec->aead_req;
720  
721  	i = msg_pl->sg.end;
722  	sk_msg_iter_var_prev(i);
723  
724  	rec->content_type = record_type;
725  	if (prot->version == TLS_1_3_VERSION) {
726  		/* Add content type to end of message.  No padding added */
727  		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
728  		sg_mark_end(&rec->sg_content_type);
729  		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
730  			 &rec->sg_content_type);
731  	} else {
732  		sg_mark_end(sk_msg_elem(msg_pl, i));
733  	}
734  
735  	if (msg_pl->sg.end < msg_pl->sg.start) {
736  		sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
737  			 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
738  			 msg_pl->sg.data);
739  	}
740  
741  	i = msg_pl->sg.start;
742  	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
743  
744  	i = msg_en->sg.end;
745  	sk_msg_iter_var_prev(i);
746  	sg_mark_end(sk_msg_elem(msg_en, i));
747  
748  	i = msg_en->sg.start;
749  	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
750  
751  	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
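	/* Additional authenticated data, roughly (see tls_make_aad()):
	 *   TLS 1.2: seq(8) || type(1) || version(2) || length(2)
	 *   TLS 1.3: the 5-byte record header only; the sequence number is
	 *            mixed into the nonce instead.
	 */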
752  		     tls_ctx->tx.rec_seq, record_type, prot);
753  
754  	tls_fill_prepend(tls_ctx,
755  			 page_address(sg_page(&msg_en->sg.data[i])) +
756  			 msg_en->sg.data[i].offset,
757  			 msg_pl->sg.size + prot->tail_size,
758  			 record_type);
759  
760  	tls_ctx->pending_open_record_frags = false;
761  
762  	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
763  			       msg_pl->sg.size + prot->tail_size, i);
764  	if (rc < 0) {
765  		if (rc != -EINPROGRESS) {
766  			tls_err_abort(sk, -EBADMSG);
767  			if (split) {
768  				tls_ctx->pending_open_record_frags = true;
769  				tls_merge_open_record(sk, rec, tmp, orig_end);
770  			}
771  		}
772  		ctx->async_capable = 1;
773  		return rc;
774  	} else if (split) {
775  		msg_pl = &tmp->msg_plaintext;
776  		msg_en = &tmp->msg_encrypted;
777  		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
778  		tls_ctx->pending_open_record_frags = true;
779  		ctx->open_rec = tmp;
780  	}
781  
782  	return tls_tx_records(sk, flags);
783  }
784  
785  static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
786  			       bool full_record, u8 record_type,
787  			       ssize_t *copied, int flags)
788  {
789  	struct tls_context *tls_ctx = tls_get_ctx(sk);
790  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
791  	struct sk_msg msg_redir = { };
792  	struct sk_psock *psock;
793  	struct sock *sk_redir;
794  	struct tls_rec *rec;
795  	bool enospc, policy;
796  	int err = 0, send;
797  	u32 delta = 0;
798  
799  	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
800  	psock = sk_psock_get(sk);
801  	if (!psock || !policy) {
802  		err = tls_push_record(sk, flags, record_type);
803  		if (err && sk->sk_err == EBADMSG) {
804  			*copied -= sk_msg_free(sk, msg);
805  			tls_free_open_rec(sk);
806  			err = -sk->sk_err;
807  		}
808  		if (psock)
809  			sk_psock_put(sk, psock);
810  		return err;
811  	}
812  more_data:
813  	enospc = sk_msg_full(msg);
814  	if (psock->eval == __SK_NONE) {
815  		delta = msg->sg.size;
816  		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
817  		delta -= msg->sg.size;
818  	}
819  	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
820  	    !enospc && !full_record) {
821  		err = -ENOSPC;
822  		goto out_err;
823  	}
824  	msg->cork_bytes = 0;
825  	send = msg->sg.size;
826  	if (msg->apply_bytes && msg->apply_bytes < send)
827  		send = msg->apply_bytes;
828  
829  	switch (psock->eval) {
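	/* Act on the sk_msg BPF verdict for this record:
	 *   __SK_PASS:     encrypt and transmit on this socket as usual,
	 *   __SK_REDIRECT: hand the plaintext off to another socket via
	 *                  tcp_bpf_sendmsg_redir(),
	 *   __SK_DROP:     free the data and fail with -EACCES.
	 */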
830  	case __SK_PASS:
831  		err = tls_push_record(sk, flags, record_type);
832  		if (err && sk->sk_err == EBADMSG) {
833  			*copied -= sk_msg_free(sk, msg);
834  			tls_free_open_rec(sk);
835  			err = -sk->sk_err;
836  			goto out_err;
837  		}
838  		break;
839  	case __SK_REDIRECT:
840  		sk_redir = psock->sk_redir;
841  		memcpy(&msg_redir, msg, sizeof(*msg));
842  		if (msg->apply_bytes < send)
843  			msg->apply_bytes = 0;
844  		else
845  			msg->apply_bytes -= send;
846  		sk_msg_return_zero(sk, msg, send);
847  		msg->sg.size -= send;
848  		release_sock(sk);
849  		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
850  		lock_sock(sk);
851  		if (err < 0) {
852  			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
853  			msg->sg.size = 0;
854  		}
855  		if (msg->sg.size == 0)
856  			tls_free_open_rec(sk);
857  		break;
858  	case __SK_DROP:
859  	default:
860  		sk_msg_free_partial(sk, msg, send);
861  		if (msg->apply_bytes < send)
862  			msg->apply_bytes = 0;
863  		else
864  			msg->apply_bytes -= send;
865  		if (msg->sg.size == 0)
866  			tls_free_open_rec(sk);
867  		*copied -= (send + delta);
868  		err = -EACCES;
869  	}
870  
871  	if (likely(!err)) {
872  		bool reset_eval = !ctx->open_rec;
873  
874  		rec = ctx->open_rec;
875  		if (rec) {
876  			msg = &rec->msg_plaintext;
877  			if (!msg->apply_bytes)
878  				reset_eval = true;
879  		}
880  		if (reset_eval) {
881  			psock->eval = __SK_NONE;
882  			if (psock->sk_redir) {
883  				sock_put(psock->sk_redir);
884  				psock->sk_redir = NULL;
885  			}
886  		}
887  		if (rec)
888  			goto more_data;
889  	}
890   out_err:
891  	sk_psock_put(sk, psock);
892  	return err;
893  }
894  
895  static int tls_sw_push_pending_record(struct sock *sk, int flags)
896  {
897  	struct tls_context *tls_ctx = tls_get_ctx(sk);
898  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
899  	struct tls_rec *rec = ctx->open_rec;
900  	struct sk_msg *msg_pl;
901  	size_t copied;
902  
903  	if (!rec)
904  		return 0;
905  
906  	msg_pl = &rec->msg_plaintext;
907  	copied = msg_pl->sg.size;
908  	if (!copied)
909  		return 0;
910  
911  	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
912  				   &copied, flags);
913  }
914  
915  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
916  {
917  	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
918  	struct tls_context *tls_ctx = tls_get_ctx(sk);
919  	struct tls_prot_info *prot = &tls_ctx->prot_info;
920  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
921  	bool async_capable = ctx->async_capable;
922  	unsigned char record_type = TLS_RECORD_TYPE_DATA;
923  	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
924  	bool eor = !(msg->msg_flags & MSG_MORE);
925  	size_t try_to_copy;
926  	ssize_t copied = 0;
927  	struct sk_msg *msg_pl, *msg_en;
928  	struct tls_rec *rec;
929  	int required_size;
930  	int num_async = 0;
931  	bool full_record;
932  	int record_room;
933  	int num_zc = 0;
934  	int orig_size;
935  	int ret = 0;
936  	int pending;
937  
938  	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
939  			       MSG_CMSG_COMPAT))
940  		return -EOPNOTSUPP;
941  
942  	mutex_lock(&tls_ctx->tx_lock);
943  	lock_sock(sk);
944  
945  	if (unlikely(msg->msg_controllen)) {
946  		ret = tls_process_cmsg(sk, msg, &record_type);
947  		if (ret) {
948  			if (ret == -EINPROGRESS)
949  				num_async++;
950  			else if (ret != -EAGAIN)
951  				goto send_end;
952  		}
953  	}
954  
955  	while (msg_data_left(msg)) {
956  		if (sk->sk_err) {
957  			ret = -sk->sk_err;
958  			goto send_end;
959  		}
960  
961  		if (ctx->open_rec)
962  			rec = ctx->open_rec;
963  		else
964  			rec = ctx->open_rec = tls_get_rec(sk);
965  		if (!rec) {
966  			ret = -ENOMEM;
967  			goto send_end;
968  		}
969  
970  		msg_pl = &rec->msg_plaintext;
971  		msg_en = &rec->msg_encrypted;
972  
973  		orig_size = msg_pl->sg.size;
974  		full_record = false;
975  		try_to_copy = msg_data_left(msg);
976  		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
977  		if (try_to_copy >= record_room) {
978  			try_to_copy = record_room;
979  			full_record = true;
980  		}
981  
982  		required_size = msg_pl->sg.size + try_to_copy +
983  				prot->overhead_size;
984  
985  		if (!sk_stream_memory_free(sk))
986  			goto wait_for_sndbuf;
987  
988  alloc_encrypted:
989  		ret = tls_alloc_encrypted_msg(sk, required_size);
990  		if (ret) {
991  			if (ret != -ENOSPC)
992  				goto wait_for_memory;
993  
994  			/* Adjust try_to_copy according to the amount that was
995  			 * actually allocated. The difference is due
996  			 * to max sg elements limit
997  			 */
998  			try_to_copy -= required_size - msg_en->sg.size;
999  			full_record = true;
1000  		}
1001  
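		/* Map the user pages straight into the plaintext message
		 * (zero-copy) when the record is about to close (full record
		 * or no MSG_MORE), the iterator is not a kvec and the AEAD
		 * completes synchronously; otherwise fall back to copying.
		 */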
1002  		if (!is_kvec && (full_record || eor) && !async_capable) {
1003  			u32 first = msg_pl->sg.end;
1004  
1005  			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
1006  							msg_pl, try_to_copy);
1007  			if (ret)
1008  				goto fallback_to_reg_send;
1009  
1010  			num_zc++;
1011  			copied += try_to_copy;
1012  
1013  			sk_msg_sg_copy_set(msg_pl, first);
1014  			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1015  						  record_type, &copied,
1016  						  msg->msg_flags);
1017  			if (ret) {
1018  				if (ret == -EINPROGRESS)
1019  					num_async++;
1020  				else if (ret == -ENOMEM)
1021  					goto wait_for_memory;
1022  				else if (ctx->open_rec && ret == -ENOSPC)
1023  					goto rollback_iter;
1024  				else if (ret != -EAGAIN)
1025  					goto send_end;
1026  			}
1027  			continue;
1028  rollback_iter:
1029  			copied -= try_to_copy;
1030  			sk_msg_sg_copy_clear(msg_pl, first);
1031  			iov_iter_revert(&msg->msg_iter,
1032  					msg_pl->sg.size - orig_size);
1033  fallback_to_reg_send:
1034  			sk_msg_trim(sk, msg_pl, orig_size);
1035  		}
1036  
1037  		required_size = msg_pl->sg.size + try_to_copy;
1038  
1039  		ret = tls_clone_plaintext_msg(sk, required_size);
1040  		if (ret) {
1041  			if (ret != -ENOSPC)
1042  				goto send_end;
1043  
1044  			/* Adjust try_to_copy according to the amount that was
1045  			 * actually allocated. The difference is due
1046  			 * to max sg elements limit
1047  			 */
1048  			try_to_copy -= required_size - msg_pl->sg.size;
1049  			full_record = true;
1050  			sk_msg_trim(sk, msg_en,
1051  				    msg_pl->sg.size + prot->overhead_size);
1052  		}
1053  
1054  		if (try_to_copy) {
1055  			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
1056  						       msg_pl, try_to_copy);
1057  			if (ret < 0)
1058  				goto trim_sgl;
1059  		}
1060  
1061  		/* Mark open record frags as pending only if the copy succeeded;
1062  		 * otherwise we would trim the sg but not reset the open record frags.
1063  		 */
1064  		tls_ctx->pending_open_record_frags = true;
1065  		copied += try_to_copy;
1066  		if (full_record || eor) {
1067  			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1068  						  record_type, &copied,
1069  						  msg->msg_flags);
1070  			if (ret) {
1071  				if (ret == -EINPROGRESS)
1072  					num_async++;
1073  				else if (ret == -ENOMEM)
1074  					goto wait_for_memory;
1075  				else if (ret != -EAGAIN) {
1076  					if (ret == -ENOSPC)
1077  						ret = 0;
1078  					goto send_end;
1079  				}
1080  			}
1081  		}
1082  
1083  		continue;
1084  
1085  wait_for_sndbuf:
1086  		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1087  wait_for_memory:
1088  		ret = sk_stream_wait_memory(sk, &timeo);
1089  		if (ret) {
1090  trim_sgl:
1091  			if (ctx->open_rec)
1092  				tls_trim_both_msgs(sk, orig_size);
1093  			goto send_end;
1094  		}
1095  
1096  		if (ctx->open_rec && msg_en->sg.size < required_size)
1097  			goto alloc_encrypted;
1098  	}
1099  
1100  	if (!num_async) {
1101  		goto send_end;
1102  	} else if (num_zc) {
1103  		/* Wait for pending encryptions to get completed */
1104  		spin_lock_bh(&ctx->encrypt_compl_lock);
1105  		ctx->async_notify = true;
1106  
1107  		pending = atomic_read(&ctx->encrypt_pending);
1108  		spin_unlock_bh(&ctx->encrypt_compl_lock);
1109  		if (pending)
1110  			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1111  		else
1112  			reinit_completion(&ctx->async_wait.completion);
1113  
1114  		/* There can be no concurrent accesses, since we have no
1115  		 * pending encrypt operations
1116  		 */
1117  		WRITE_ONCE(ctx->async_notify, false);
1118  
1119  		if (ctx->async_wait.err) {
1120  			ret = ctx->async_wait.err;
1121  			copied = 0;
1122  		}
1123  	}
1124  
1125  	/* Transmit if any encryptions have completed */
1126  	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1127  		cancel_delayed_work(&ctx->tx_work.work);
1128  		tls_tx_records(sk, msg->msg_flags);
1129  	}
1130  
1131  send_end:
1132  	ret = sk_stream_error(sk, msg->msg_flags, ret);
1133  
1134  	release_sock(sk);
1135  	mutex_unlock(&tls_ctx->tx_lock);
1136  	return copied > 0 ? copied : ret;
1137  }
1138  
1139  static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
1140  			      int offset, size_t size, int flags)
1141  {
1142  	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1143  	struct tls_context *tls_ctx = tls_get_ctx(sk);
1144  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1145  	struct tls_prot_info *prot = &tls_ctx->prot_info;
1146  	unsigned char record_type = TLS_RECORD_TYPE_DATA;
1147  	struct sk_msg *msg_pl;
1148  	struct tls_rec *rec;
1149  	int num_async = 0;
1150  	ssize_t copied = 0;
1151  	bool full_record;
1152  	int record_room;
1153  	int ret = 0;
1154  	bool eor;
1155  
1156  	eor = !(flags & MSG_SENDPAGE_NOTLAST);
1157  	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1158  
1159  	/* Call the sk_stream functions to manage the sndbuf mem. */
1160  	while (size > 0) {
1161  		size_t copy, required_size;
1162  
1163  		if (sk->sk_err) {
1164  			ret = -sk->sk_err;
1165  			goto sendpage_end;
1166  		}
1167  
1168  		if (ctx->open_rec)
1169  			rec = ctx->open_rec;
1170  		else
1171  			rec = ctx->open_rec = tls_get_rec(sk);
1172  		if (!rec) {
1173  			ret = -ENOMEM;
1174  			goto sendpage_end;
1175  		}
1176  
1177  		msg_pl = &rec->msg_plaintext;
1178  
1179  		full_record = false;
1180  		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1181  		copy = size;
1182  		if (copy >= record_room) {
1183  			copy = record_room;
1184  			full_record = true;
1185  		}
1186  
1187  		required_size = msg_pl->sg.size + copy + prot->overhead_size;
1188  
1189  		if (!sk_stream_memory_free(sk))
1190  			goto wait_for_sndbuf;
1191  alloc_payload:
1192  		ret = tls_alloc_encrypted_msg(sk, required_size);
1193  		if (ret) {
1194  			if (ret != -ENOSPC)
1195  				goto wait_for_memory;
1196  
1197  			/* Adjust copy according to the amount that was
1198  			 * actually allocated. The difference is due
1199  			 * to max sg elements limit
1200  			 */
1201  			copy -= required_size - msg_pl->sg.size;
1202  			full_record = true;
1203  		}
1204  
1205  		sk_msg_page_add(msg_pl, page, copy, offset);
1206  		sk_mem_charge(sk, copy);
1207  
1208  		offset += copy;
1209  		size -= copy;
1210  		copied += copy;
1211  
1212  		tls_ctx->pending_open_record_frags = true;
1213  		if (full_record || eor || sk_msg_full(msg_pl)) {
1214  			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1215  						  record_type, &copied, flags);
1216  			if (ret) {
1217  				if (ret == -EINPROGRESS)
1218  					num_async++;
1219  				else if (ret == -ENOMEM)
1220  					goto wait_for_memory;
1221  				else if (ret != -EAGAIN) {
1222  					if (ret == -ENOSPC)
1223  						ret = 0;
1224  					goto sendpage_end;
1225  				}
1226  			}
1227  		}
1228  		continue;
1229  wait_for_sndbuf:
1230  		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1231  wait_for_memory:
1232  		ret = sk_stream_wait_memory(sk, &timeo);
1233  		if (ret) {
1234  			if (ctx->open_rec)
1235  				tls_trim_both_msgs(sk, msg_pl->sg.size);
1236  			goto sendpage_end;
1237  		}
1238  
1239  		if (ctx->open_rec)
1240  			goto alloc_payload;
1241  	}
1242  
1243  	if (num_async) {
1244  		/* Transmit if any encryptions have completed */
1245  		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1246  			cancel_delayed_work(&ctx->tx_work.work);
1247  			tls_tx_records(sk, flags);
1248  		}
1249  	}
1250  sendpage_end:
1251  	ret = sk_stream_error(sk, flags, ret);
1252  	return copied > 0 ? copied : ret;
1253  }
1254  
1255  int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
1256  			   int offset, size_t size, int flags)
1257  {
1258  	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1259  		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
1260  		      MSG_NO_SHARED_FRAGS))
1261  		return -EOPNOTSUPP;
1262  
1263  	return tls_sw_do_sendpage(sk, page, offset, size, flags);
1264  }
1265  
1266  int tls_sw_sendpage(struct sock *sk, struct page *page,
1267  		    int offset, size_t size, int flags)
1268  {
1269  	struct tls_context *tls_ctx = tls_get_ctx(sk);
1270  	int ret;
1271  
1272  	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1273  		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
1274  		return -EOPNOTSUPP;
1275  
1276  	mutex_lock(&tls_ctx->tx_lock);
1277  	lock_sock(sk);
1278  	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
1279  	release_sock(sk);
1280  	mutex_unlock(&tls_ctx->tx_lock);
1281  	return ret;
1282  }
1283  
1284  static int
1285  tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
1286  		bool released)
1287  {
1288  	struct tls_context *tls_ctx = tls_get_ctx(sk);
1289  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1290  	DEFINE_WAIT_FUNC(wait, woken_wake_function);
1291  	long timeo;
1292  
1293  	timeo = sock_rcvtimeo(sk, nonblock);
1294  
1295  	while (!tls_strp_msg_ready(ctx)) {
1296  		if (!sk_psock_queue_empty(psock))
1297  			return 0;
1298  
1299  		if (sk->sk_err)
1300  			return sock_error(sk);
1301  
1302  		if (!skb_queue_empty(&sk->sk_receive_queue)) {
1303  			tls_strp_check_rcv(&ctx->strp);
1304  			if (tls_strp_msg_ready(ctx))
1305  				break;
1306  		}
1307  
1308  		if (sk->sk_shutdown & RCV_SHUTDOWN)
1309  			return 0;
1310  
1311  		if (sock_flag(sk, SOCK_DONE))
1312  			return 0;
1313  
1314  		if (!timeo)
1315  			return -EAGAIN;
1316  
1317  		released = true;
1318  		add_wait_queue(sk_sleep(sk), &wait);
1319  		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1320  		sk_wait_event(sk, &timeo,
1321  			      tls_strp_msg_ready(ctx) ||
1322  			      !sk_psock_queue_empty(psock),
1323  			      &wait);
1324  		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1325  		remove_wait_queue(sk_sleep(sk), &wait);
1326  
1327  		/* Handle signals */
1328  		if (signal_pending(current))
1329  			return sock_intr_errno(timeo);
1330  	}
1331  
1332  	tls_strp_msg_load(&ctx->strp, released);
1333  
1334  	return 1;
1335  }
1336  
1337  static int tls_setup_from_iter(struct iov_iter *from,
1338  			       int length, int *pages_used,
1339  			       struct scatterlist *to,
1340  			       int to_max_pages)
1341  {
1342  	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1343  	struct page *pages[MAX_SKB_FRAGS];
1344  	unsigned int size = 0;
1345  	ssize_t copied, use;
1346  	size_t offset;
1347  
1348  	while (length > 0) {
1349  		i = 0;
1350  		maxpages = to_max_pages - num_elem;
1351  		if (maxpages == 0) {
1352  			rc = -EFAULT;
1353  			goto out;
1354  		}
1355  		copied = iov_iter_get_pages2(from, pages,
1356  					    length,
1357  					    maxpages, &offset);
1358  		if (copied <= 0) {
1359  			rc = -EFAULT;
1360  			goto out;
1361  		}
1362  
1363  		length -= copied;
1364  		size += copied;
1365  		while (copied) {
1366  			use = min_t(int, copied, PAGE_SIZE - offset);
1367  
1368  			sg_set_page(&to[num_elem],
1369  				    pages[i], use, offset);
1370  			sg_unmark_end(&to[num_elem]);
1371  			/* We do not uncharge memory from this API */
1372  
1373  			offset = 0;
1374  			copied -= use;
1375  
1376  			i++;
1377  			num_elem++;
1378  		}
1379  	}
1380  	/* Mark the end in the last sg entry if newly added */
1381  	if (num_elem > *pages_used)
1382  		sg_mark_end(&to[num_elem - 1]);
1383  out:
1384  	if (rc)
1385  		iov_iter_revert(from, size);
1386  	*pages_used = num_elem;
1387  
1388  	return rc;
1389  }
1390  
1391  static struct sk_buff *
1392  tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
1393  		     unsigned int full_len)
1394  {
1395  	struct strp_msg *clr_rxm;
1396  	struct sk_buff *clr_skb;
1397  	int err;
1398  
1399  	clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
1400  				       &err, sk->sk_allocation);
1401  	if (!clr_skb)
1402  		return NULL;
1403  
1404  	skb_copy_header(clr_skb, skb);
1405  	clr_skb->len = full_len;
1406  	clr_skb->data_len = full_len;
1407  
1408  	clr_rxm = strp_msg(clr_skb);
1409  	clr_rxm->offset = 0;
1410  
1411  	return clr_skb;
1412  }
1413  
1414  /* Decrypt handlers
1415   *
1416   * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers.
1417   * They must transform the darg in/out argument as follows:
1418   *       |          Input            |         Output
1419   * -------------------------------------------------------------------
1420   *    zc | Zero-copy decrypt allowed | Zero-copy performed
1421   * async | Async decrypt allowed     | Async crypto used / in progress
1422   *   skb |            *              | Output skb
1423   *
1424   * If ZC decryption was performed darg.skb will point to the input skb.
1425   */
1426  
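/* Illustrative caller setup (cf. decrypt_skb() and tls_rx_one_record() below):
 *
 *	struct tls_decrypt_arg darg = { .zc = true, };
 *
 *	err = tls_decrypt_sg(sk, NULL, sgout, &darg);
 *
 * On return darg.skb points at the decrypted record while darg.zc and
 * darg.async report what the handler actually did.
 */
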
1427  /* This function decrypts the input skb into either out_iov or out_sg,
1428   * or into a newly allocated cleartext skb. The input parameter 'darg->zc'
1429   * indicates whether zero-copy mode should be tried. With zero-copy mode,
1430   * either out_iov or out_sg must be non-NULL. If both out_iov and out_sg
1431   * are NULL, the decryption happens into freshly allocated skb buffers,
1432   * i.e. zero-copy gets disabled and 'darg->zc' is updated.
1433   */
1434  static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
1435  			  struct scatterlist *out_sg,
1436  			  struct tls_decrypt_arg *darg)
1437  {
1438  	struct tls_context *tls_ctx = tls_get_ctx(sk);
1439  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1440  	struct tls_prot_info *prot = &tls_ctx->prot_info;
1441  	int n_sgin, n_sgout, aead_size, err, pages = 0;
1442  	struct sk_buff *skb = tls_strp_msg(ctx);
1443  	const struct strp_msg *rxm = strp_msg(skb);
1444  	const struct tls_msg *tlm = tls_msg(skb);
1445  	struct aead_request *aead_req;
1446  	struct scatterlist *sgin = NULL;
1447  	struct scatterlist *sgout = NULL;
1448  	const int data_len = rxm->full_len - prot->overhead_size;
1449  	int tail_pages = !!prot->tail_size;
1450  	struct tls_decrypt_ctx *dctx;
1451  	struct sk_buff *clear_skb;
1452  	int iv_offset = 0;
1453  	u8 *mem;
1454  
1455  	n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
1456  			 rxm->full_len - prot->prepend_size);
1457  	if (n_sgin < 1)
1458  		return n_sgin ?: -EBADMSG;
1459  
1460  	if (darg->zc && (out_iov || out_sg)) {
1461  		clear_skb = NULL;
1462  
1463  		if (out_iov)
1464  			n_sgout = 1 + tail_pages +
1465  				iov_iter_npages_cap(out_iov, INT_MAX, data_len);
1466  		else
1467  			n_sgout = sg_nents(out_sg);
1468  	} else {
1469  		darg->zc = false;
1470  
1471  		clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
1472  		if (!clear_skb)
1473  			return -ENOMEM;
1474  
1475  		n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
1476  	}
1477  
1478  	/* Increment to accommodate AAD */
1479  	n_sgin = n_sgin + 1;
1480  
1481  	/* Allocate a single block of memory which contains
1482  	 *   aead_req || tls_decrypt_ctx.
1483  	 * Both structs are variable length.
1484  	 */
1485  	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
1486  	mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
1487  		      sk->sk_allocation);
1488  	if (!mem) {
1489  		err = -ENOMEM;
1490  		goto exit_free_skb;
1491  	}
1492  
1493  	/* Segment the allocated memory */
1494  	aead_req = (struct aead_request *)mem;
1495  	dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
1496  	sgin = &dctx->sg[0];
1497  	sgout = &dctx->sg[n_sgin];
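	/* Resulting layout of the single allocation (illustrative):
	 *
	 *   mem: | aead_req + crypto req ctx | tls_decrypt_ctx | sg[n_sgin + n_sgout] |
	 *          ^aead_req                   ^dctx             ^sgin       ^sgout
	 */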
1498  
1499  	/* For CCM based ciphers, first byte of nonce+iv is a constant */
1500  	switch (prot->cipher_type) {
1501  	case TLS_CIPHER_AES_CCM_128:
1502  		dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
1503  		iv_offset = 1;
1504  		break;
1505  	case TLS_CIPHER_SM4_CCM:
1506  		dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
1507  		iv_offset = 1;
1508  		break;
1509  	}
1510  
1511  	/* Prepare IV */
1512  	if (prot->version == TLS_1_3_VERSION ||
1513  	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
1514  		memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
1515  		       prot->iv_size + prot->salt_size);
1516  	} else {
1517  		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
1518  				    &dctx->iv[iv_offset] + prot->salt_size,
1519  				    prot->iv_size);
1520  		if (err < 0)
1521  			goto exit_free;
1522  		memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
1523  	}
1524  	tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
1525  
1526  	/* Prepare AAD */
1527  	tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
1528  		     prot->tail_size,
1529  		     tls_ctx->rx.rec_seq, tlm->control, prot);
1530  
1531  	/* Prepare sgin */
1532  	sg_init_table(sgin, n_sgin);
1533  	sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
1534  	err = skb_to_sgvec(skb, &sgin[1],
1535  			   rxm->offset + prot->prepend_size,
1536  			   rxm->full_len - prot->prepend_size);
1537  	if (err < 0)
1538  		goto exit_free;
1539  
1540  	if (clear_skb) {
1541  		sg_init_table(sgout, n_sgout);
1542  		sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
1543  
1544  		err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
1545  				   data_len + prot->tail_size);
1546  		if (err < 0)
1547  			goto exit_free;
1548  	} else if (out_iov) {
1549  		sg_init_table(sgout, n_sgout);
1550  		sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
1551  
1552  		err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
1553  					  (n_sgout - 1 - tail_pages));
1554  		if (err < 0)
1555  			goto exit_free_pages;
1556  
1557  		if (prot->tail_size) {
1558  			sg_unmark_end(&sgout[pages]);
1559  			sg_set_buf(&sgout[pages + 1], &dctx->tail,
1560  				   prot->tail_size);
1561  			sg_mark_end(&sgout[pages + 1]);
1562  		}
1563  	} else if (out_sg) {
1564  		memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
1565  	}
1566  
1567  	/* Prepare and submit AEAD request */
1568  	err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
1569  				data_len + prot->tail_size, aead_req, darg);
1570  	if (err)
1571  		goto exit_free_pages;
1572  
1573  	darg->skb = clear_skb ?: tls_strp_msg(ctx);
1574  	clear_skb = NULL;
1575  
1576  	if (unlikely(darg->async)) {
1577  		err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
1578  		if (err)
1579  			__skb_queue_tail(&ctx->async_hold, darg->skb);
1580  		return err;
1581  	}
1582  
1583  	if (prot->tail_size)
1584  		darg->tail = dctx->tail;
1585  
1586  exit_free_pages:
1587  	/* Release the pages in case iov was mapped to pages */
1588  	for (; pages > 0; pages--)
1589  		put_page(sg_page(&sgout[pages]));
1590  exit_free:
1591  	kfree(mem);
1592  exit_free_skb:
1593  	consume_skb(clear_skb);
1594  	return err;
1595  }
1596  
1597  static int
1598  tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx,
1599  	       struct msghdr *msg, struct tls_decrypt_arg *darg)
1600  {
1601  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1602  	struct tls_prot_info *prot = &tls_ctx->prot_info;
1603  	struct strp_msg *rxm;
1604  	int pad, err;
1605  
1606  	err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg);
1607  	if (err < 0) {
1608  		if (err == -EBADMSG)
1609  			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
1610  		return err;
1611  	}
1612  	/* keep going even for ->async, the code below is TLS 1.3 */
1613  
1614  	/* If opportunistic TLS 1.3 ZC failed retry without ZC */
1615  	if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
1616  		     darg->tail != TLS_RECORD_TYPE_DATA)) {
1617  		darg->zc = false;
1618  		if (!darg->tail)
1619  			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
1620  		TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
1621  		return tls_decrypt_sw(sk, tls_ctx, msg, darg);
1622  	}
1623  
1624  	pad = tls_padding_length(prot, darg->skb, darg);
1625  	if (pad < 0) {
1626  		if (darg->skb != tls_strp_msg(ctx))
1627  			consume_skb(darg->skb);
1628  		return pad;
1629  	}
1630  
1631  	rxm = strp_msg(darg->skb);
1632  	rxm->full_len -= pad;
1633  
1634  	return 0;
1635  }
1636  
1637  static int
1638  tls_decrypt_device(struct sock *sk, struct msghdr *msg,
1639  		   struct tls_context *tls_ctx, struct tls_decrypt_arg *darg)
1640  {
1641  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1642  	struct tls_prot_info *prot = &tls_ctx->prot_info;
1643  	struct strp_msg *rxm;
1644  	int pad, err;
1645  
1646  	if (tls_ctx->rx_conf != TLS_HW)
1647  		return 0;
1648  
1649  	err = tls_device_decrypted(sk, tls_ctx);
1650  	if (err <= 0)
1651  		return err;
1652  
1653  	pad = tls_padding_length(prot, tls_strp_msg(ctx), darg);
1654  	if (pad < 0)
1655  		return pad;
1656  
1657  	darg->async = false;
1658  	darg->skb = tls_strp_msg(ctx);
1659  	/* ->zc downgrade check, in case TLS 1.3 gets here */
1660  	darg->zc &= !(prot->version == TLS_1_3_VERSION &&
1661  		      tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA);
1662  
1663  	rxm = strp_msg(darg->skb);
1664  	rxm->full_len -= pad;
1665  
1666  	if (!darg->zc) {
1667  		/* Non-ZC case needs a real skb */
1668  		darg->skb = tls_strp_msg_detach(ctx);
1669  		if (!darg->skb)
1670  			return -ENOMEM;
1671  	} else {
1672  		unsigned int off, len;
1673  
1674  		/* In ZC case nobody cares about the output skb.
1675  		 * Just copy the data here. Note the skb is not fully trimmed.
1676  		 */
1677  		off = rxm->offset + prot->prepend_size;
1678  		len = rxm->full_len - prot->overhead_size;
1679  
1680  		err = skb_copy_datagram_msg(darg->skb, off, msg, len);
1681  		if (err)
1682  			return err;
1683  	}
1684  	return 1;
1685  }
1686  
1687  static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
1688  			     struct tls_decrypt_arg *darg)
1689  {
1690  	struct tls_context *tls_ctx = tls_get_ctx(sk);
1691  	struct tls_prot_info *prot = &tls_ctx->prot_info;
1692  	struct strp_msg *rxm;
1693  	int err;
1694  
1695  	err = tls_decrypt_device(sk, msg, tls_ctx, darg);
1696  	if (!err)
1697  		err = tls_decrypt_sw(sk, tls_ctx, msg, darg);
1698  	if (err < 0)
1699  		return err;
1700  
1701  	rxm = strp_msg(darg->skb);
1702  	rxm->offset += prot->prepend_size;
1703  	rxm->full_len -= prot->overhead_size;
1704  	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1705  
1706  	return 0;
1707  }
1708  
1709  int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
1710  {
1711  	struct tls_decrypt_arg darg = { .zc = true, };
1712  
1713  	return tls_decrypt_sg(sk, NULL, sgout, &darg);
1714  }
1715  
1716  static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
1717  				   u8 *control)
1718  {
1719  	int err;
1720  
1721  	if (!*control) {
1722  		*control = tlm->control;
1723  		if (!*control)
1724  			return -EBADMSG;
1725  
1726  		err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1727  			       sizeof(*control), control);
1728  		if (*control != TLS_RECORD_TYPE_DATA) {
1729  			if (err || msg->msg_flags & MSG_CTRUNC)
1730  				return -EIO;
1731  		}
1732  	} else if (*control != tlm->control) {
1733  		return 0;
1734  	}
1735  
1736  	return 1;
1737  }
1738  
1739  static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
1740  {
1741  	tls_strp_msg_done(&ctx->strp);
1742  }
1743  
1744  /* This function traverses the rx_list in the tls receive context and copies
1745   * the decrypted records into the buffer provided by the caller when zero copy
1746   * is not in use. Further, a record is removed from the rx_list if it is not a
1747   * peek case and the record has been consumed completely.
1748   */
1749  static int process_rx_list(struct tls_sw_context_rx *ctx,
1750  			   struct msghdr *msg,
1751  			   u8 *control,
1752  			   size_t skip,
1753  			   size_t len,
1754  			   bool is_peek)
1755  {
1756  	struct sk_buff *skb = skb_peek(&ctx->rx_list);
1757  	struct tls_msg *tlm;
1758  	ssize_t copied = 0;
1759  	int err;
1760  
1761  	while (skip && skb) {
1762  		struct strp_msg *rxm = strp_msg(skb);
1763  		tlm = tls_msg(skb);
1764  
1765  		err = tls_record_content_type(msg, tlm, control);
1766  		if (err <= 0)
1767  			goto out;
1768  
1769  		if (skip < rxm->full_len)
1770  			break;
1771  
1772  		skip = skip - rxm->full_len;
1773  		skb = skb_peek_next(skb, &ctx->rx_list);
1774  	}
1775  
1776  	while (len && skb) {
1777  		struct sk_buff *next_skb;
1778  		struct strp_msg *rxm = strp_msg(skb);
1779  		int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1780  
1781  		tlm = tls_msg(skb);
1782  
1783  		err = tls_record_content_type(msg, tlm, control);
1784  		if (err <= 0)
1785  			goto out;
1786  
1787  		err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1788  					    msg, chunk);
1789  		if (err < 0)
1790  			goto out;
1791  
1792  		len = len - chunk;
1793  		copied = copied + chunk;
1794  
1795  		/* Consume the data from the record in the non-peek case */
1796  		if (!is_peek) {
1797  			rxm->offset = rxm->offset + chunk;
1798  			rxm->full_len = rxm->full_len - chunk;
1799  
1800  			/* Return if there is unconsumed data in the record */
1801  			if (rxm->full_len - skip)
1802  				break;
1803  		}
1804  
1805  		/* The remaining skip-bytes must lie in the 1st record in rx_list.
1806  		 * So from the 2nd record, 'skip' should be 0.
1807  		 */
1808  		skip = 0;
1809  
1810  		if (msg)
1811  			msg->msg_flags |= MSG_EOR;
1812  
1813  		next_skb = skb_peek_next(skb, &ctx->rx_list);
1814  
1815  		if (!is_peek) {
1816  			__skb_unlink(skb, &ctx->rx_list);
1817  			consume_skb(skb);
1818  		}
1819  
1820  		skb = next_skb;
1821  	}
1822  	err = 0;
1823  
1824  out:
1825  	return copied ? : err;
1826  }
1827  
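      /* Periodically flush the socket backlog during large reads so that
       * packets queued to the backlog while we own the socket get processed.
       * The flush is skipped while less than 128K has been copied since the
       * last one and more than a full record is still queued in TCP.
       * Returns true if the backlog was actually flushed.
       */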
1828  static bool
1829  tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
1830  		       size_t len_left, size_t decrypted, ssize_t done,
1831  		       size_t *flushed_at)
1832  {
1833  	size_t max_rec;
1834  
1835  	if (len_left <= decrypted)
1836  		return false;
1837  
1838  	max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
1839  	if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
1840  		return false;
1841  
1842  	*flushed_at = done;
1843  	return sk_flush_backlog(sk);
1844  }
1845  
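      /* Take the socket lock and wait until no other reader owns the RX
       * context; recvmsg and splice_read are serialized this way.  Returns
       * 0 with the socket lock held, or a negative error (e.g. -EAGAIN on
       * timeout, or a signal error) with the lock released.
       */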
1846  static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
1847  			      bool nonblock)
1848  {
1849  	long timeo;
1850  	int err;
1851  
1852  	lock_sock(sk);
1853  
1854  	timeo = sock_rcvtimeo(sk, nonblock);
1855  
1856  	while (unlikely(ctx->reader_present)) {
1857  		DEFINE_WAIT_FUNC(wait, woken_wake_function);
1858  
1859  		ctx->reader_contended = 1;
1860  
1861  		add_wait_queue(&ctx->wq, &wait);
1862  		sk_wait_event(sk, &timeo,
1863  			      !READ_ONCE(ctx->reader_present), &wait);
1864  		remove_wait_queue(&ctx->wq, &wait);
1865  
1866  		if (timeo <= 0) {
1867  			err = -EAGAIN;
1868  			goto err_unlock;
1869  		}
1870  		if (signal_pending(current)) {
1871  			err = sock_intr_errno(timeo);
1872  			goto err_unlock;
1873  		}
1874  	}
1875  
1876  	WRITE_ONCE(ctx->reader_present, 1);
1877  
1878  	return 0;
1879  
1880  err_unlock:
1881  	release_sock(sk);
1882  	return err;
1883  }
1884  
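      /* Drop reader ownership of the RX context, wake up a contended
       * reader if there is one, and release the socket lock.
       */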
1885  static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
1886  {
1887  	if (unlikely(ctx->reader_contended)) {
1888  		if (wq_has_sleeper(&ctx->wq))
1889  			wake_up(&ctx->wq);
1890  		else
1891  			ctx->reader_contended = 0;
1892  
1893  		WARN_ON_ONCE(!ctx->reader_present);
1894  	}
1895  
1896  	WRITE_ONCE(ctx->reader_present, 0);
1897  	release_sock(sk);
1898  }
1899  
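      /* Main software TLS receive path.  Already-decrypted data queued on
       * rx_list is copied out first; further records are then pulled from
       * the stream parser and decrypted, zero-copy into the iov when
       * possible and asynchronously when the AEAD supports it, until @len
       * bytes or the rcvlowat target have been consumed.
       */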
1900  int tls_sw_recvmsg(struct sock *sk,
1901  		   struct msghdr *msg,
1902  		   size_t len,
1903  		   int flags,
1904  		   int *addr_len)
1905  {
1906  	struct tls_context *tls_ctx = tls_get_ctx(sk);
1907  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1908  	struct tls_prot_info *prot = &tls_ctx->prot_info;
1909  	ssize_t decrypted = 0, async_copy_bytes = 0;
1910  	struct sk_psock *psock;
1911  	unsigned char control = 0;
1912  	size_t flushed_at = 0;
1913  	struct strp_msg *rxm;
1914  	struct tls_msg *tlm;
1915  	ssize_t copied = 0;
1916  	bool async = false;
1917  	int target, err;
1918  	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1919  	bool is_peek = flags & MSG_PEEK;
1920  	bool released = true;
1921  	bool bpf_strp_enabled;
1922  	bool zc_capable;
1923  
1924  	if (unlikely(flags & MSG_ERRQUEUE))
1925  		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1926  
1927  	psock = sk_psock_get(sk);
1928  	err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
1929  	if (err < 0)
1930  		return err;
1931  	bpf_strp_enabled = sk_psock_strp_enabled(psock);
1932  
1933  	/* If crypto failed the connection is broken */
1934  	err = ctx->async_wait.err;
1935  	if (err)
1936  		goto end;
1937  
1938  	/* Process pending decrypted records; these are always copied, never zero-copy */
1939  	err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
1940  	if (err < 0)
1941  		goto end;
1942  
1943  	copied = err;
1944  	if (len <= copied)
1945  		goto end;
1946  
1947  	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1948  	len = len - copied;
1949  
1950  	zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
1951  		ctx->zc_capable;
1952  	decrypted = 0;
1953  	while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) {
1954  		struct tls_decrypt_arg darg;
1955  		int to_decrypt, chunk;
1956  
1957  		err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT,
1958  				      released);
1959  		if (err <= 0) {
1960  			if (psock) {
1961  				chunk = sk_msg_recvmsg(sk, psock, msg, len,
1962  						       flags);
1963  				if (chunk > 0) {
1964  					decrypted += chunk;
1965  					len -= chunk;
1966  					continue;
1967  				}
1968  			}
1969  			goto recv_end;
1970  		}
1971  
1972  		memset(&darg.inargs, 0, sizeof(darg.inargs));
1973  
1974  		rxm = strp_msg(tls_strp_msg(ctx));
1975  		tlm = tls_msg(tls_strp_msg(ctx));
1976  
1977  		to_decrypt = rxm->full_len - prot->overhead_size;
1978  
1979  		if (zc_capable && to_decrypt <= len &&
1980  		    tlm->control == TLS_RECORD_TYPE_DATA)
1981  			darg.zc = true;
1982  
1983  		/* Do not use async mode if record is non-data */
1984  		if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1985  			darg.async = ctx->async_capable;
1986  		else
1987  			darg.async = false;
1988  
1989  		err = tls_rx_one_record(sk, msg, &darg);
1990  		if (err < 0) {
1991  			tls_err_abort(sk, -EBADMSG);
1992  			goto recv_end;
1993  		}
1994  
1995  		async |= darg.async;
1996  
1997  		/* If the type of records being processed is not known yet,
1998  		 * set it to the record type just dequeued. If it is already
1999  		 * known but does not match the record type just dequeued, go
2000  		 * to the end. The record type is always available here: for
2001  		 * TLS 1.2 it is known as soon as the record is dequeued from
2002  		 * the stream parser, and for TLS 1.3 async decryption is disabled.
2003  		 */
2004  		err = tls_record_content_type(msg, tls_msg(darg.skb), &control);
2005  		if (err <= 0) {
2006  			DEBUG_NET_WARN_ON_ONCE(darg.zc);
2007  			tls_rx_rec_done(ctx);
2008  put_on_rx_list_err:
2009  			__skb_queue_tail(&ctx->rx_list, darg.skb);
2010  			goto recv_end;
2011  		}
2012  
2013  		/* periodically flush backlog, and feed strparser */
2014  		released = tls_read_flush_backlog(sk, prot, len, to_decrypt,
2015  						  decrypted + copied,
2016  						  &flushed_at);
2017  
2018  		/* TLS 1.3 may have updated the length by more than overhead */
2019  		rxm = strp_msg(darg.skb);
2020  		chunk = rxm->full_len;
2021  		tls_rx_rec_done(ctx);
2022  
2023  		if (!darg.zc) {
2024  			bool partially_consumed = chunk > len;
2025  			struct sk_buff *skb = darg.skb;
2026  
2027  			DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor);
2028  
2029  			if (async) {
2030  				/* TLS 1.2-only, to_decrypt must be text len */
2031  				chunk = min_t(int, to_decrypt, len);
2032  				async_copy_bytes += chunk;
2033  put_on_rx_list:
2034  				decrypted += chunk;
2035  				len -= chunk;
2036  				__skb_queue_tail(&ctx->rx_list, skb);
2037  				continue;
2038  			}
2039  
2040  			if (bpf_strp_enabled) {
2041  				released = true;
2042  				err = sk_psock_tls_strp_read(psock, skb);
2043  				if (err != __SK_PASS) {
2044  					rxm->offset = rxm->offset + rxm->full_len;
2045  					rxm->full_len = 0;
2046  					if (err == __SK_DROP)
2047  						consume_skb(skb);
2048  					continue;
2049  				}
2050  			}
2051  
2052  			if (partially_consumed)
2053  				chunk = len;
2054  
2055  			err = skb_copy_datagram_msg(skb, rxm->offset,
2056  						    msg, chunk);
2057  			if (err < 0)
2058  				goto put_on_rx_list_err;
2059  
2060  			if (is_peek)
2061  				goto put_on_rx_list;
2062  
2063  			if (partially_consumed) {
2064  				rxm->offset += chunk;
2065  				rxm->full_len -= chunk;
2066  				goto put_on_rx_list;
2067  			}
2068  
2069  			consume_skb(skb);
2070  		}
2071  
2072  		decrypted += chunk;
2073  		len -= chunk;
2074  
2075  		/* Return full control message to userspace before trying
2076  		 * to parse another message type
2077  		 */
2078  		msg->msg_flags |= MSG_EOR;
2079  		if (control != TLS_RECORD_TYPE_DATA)
2080  			break;
2081  	}
2082  
2083  recv_end:
2084  	if (async) {
2085  		int ret, pending;
2086  
2087  		/* Wait for all previously submitted records to be decrypted */
2088  		spin_lock_bh(&ctx->decrypt_compl_lock);
2089  		reinit_completion(&ctx->async_wait.completion);
2090  		pending = atomic_read(&ctx->decrypt_pending);
2091  		spin_unlock_bh(&ctx->decrypt_compl_lock);
2092  		ret = 0;
2093  		if (pending)
2094  			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2095  		__skb_queue_purge(&ctx->async_hold);
2096  
2097  		if (ret) {
2098  			if (err >= 0 || err == -EINPROGRESS)
2099  				err = ret;
2100  			decrypted = 0;
2101  			goto end;
2102  		}
2103  
2104  		/* Drain records from the rx_list & copy if required */
2105  		if (is_peek || is_kvec)
2106  			err = process_rx_list(ctx, msg, &control, copied,
2107  					      decrypted, is_peek);
2108  		else
2109  			err = process_rx_list(ctx, msg, &control, 0,
2110  					      async_copy_bytes, is_peek);
2111  		decrypted = max(err, 0);
2112  	}
2113  
2114  	copied += decrypted;
2115  
2116  end:
2117  	tls_rx_reader_unlock(sk, ctx);
2118  	if (psock)
2119  		sk_psock_put(sk, psock);
2120  	return copied ? : err;
2121  }
2122  
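      /* splice(2) support: take one data record, either left over on
       * rx_list or freshly decrypted, and splice its payload into @pipe.
       * Control records cannot be spliced and yield -EINVAL.
       */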
2123  ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
2124  			   struct pipe_inode_info *pipe,
2125  			   size_t len, unsigned int flags)
2126  {
2127  	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
2128  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2129  	struct strp_msg *rxm = NULL;
2130  	struct sock *sk = sock->sk;
2131  	struct tls_msg *tlm;
2132  	struct sk_buff *skb;
2133  	ssize_t copied = 0;
2134  	int chunk;
2135  	int err;
2136  
2137  	err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
2138  	if (err < 0)
2139  		return err;
2140  
2141  	if (!skb_queue_empty(&ctx->rx_list)) {
2142  		skb = __skb_dequeue(&ctx->rx_list);
2143  	} else {
2144  		struct tls_decrypt_arg darg;
2145  
2146  		err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
2147  				      true);
2148  		if (err <= 0)
2149  			goto splice_read_end;
2150  
2151  		memset(&darg.inargs, 0, sizeof(darg.inargs));
2152  
2153  		err = tls_rx_one_record(sk, NULL, &darg);
2154  		if (err < 0) {
2155  			tls_err_abort(sk, -EBADMSG);
2156  			goto splice_read_end;
2157  		}
2158  
2159  		tls_rx_rec_done(ctx);
2160  		skb = darg.skb;
2161  	}
2162  
2163  	rxm = strp_msg(skb);
2164  	tlm = tls_msg(skb);
2165  
2166  	/* splice does not support reading control messages */
2167  	if (tlm->control != TLS_RECORD_TYPE_DATA) {
2168  		err = -EINVAL;
2169  		goto splice_requeue;
2170  	}
2171  
2172  	chunk = min_t(unsigned int, rxm->full_len, len);
2173  	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
2174  	if (copied < 0)
2175  		goto splice_requeue;
2176  
2177  	if (chunk < rxm->full_len) {
2178  		rxm->offset += len;
2179  		rxm->full_len -= len;
2180  		goto splice_requeue;
2181  	}
2182  
2183  	consume_skb(skb);
2184  
2185  splice_read_end:
2186  	tls_rx_reader_unlock(sk, ctx);
2187  	return copied ? : err;
2188  
2189  splice_requeue:
2190  	__skb_queue_head(&ctx->rx_list, skb);
2191  	goto splice_read_end;
2192  }
2193  
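      /* The socket is readable when the BPF psock ingress queue is
       * non-empty, the stream parser holds a complete record, or decrypted
       * records are queued on rx_list.
       */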
2194  bool tls_sw_sock_is_readable(struct sock *sk)
2195  {
2196  	struct tls_context *tls_ctx = tls_get_ctx(sk);
2197  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2198  	bool ingress_empty = true;
2199  	struct sk_psock *psock;
2200  
2201  	rcu_read_lock();
2202  	psock = sk_psock(sk);
2203  	if (psock)
2204  		ingress_empty = list_empty(&psock->ingress_msg);
2205  	rcu_read_unlock();
2206  
2207  	return !ingress_empty || tls_strp_msg_ready(ctx) ||
2208  		!skb_queue_empty(&ctx->rx_list);
2209  }
2210  
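      /* Stream parser callback: parse the TLS record header at the current
       * offset and return the length of the full record (header included),
       * 0 if more data is needed, or a negative error after aborting the
       * connection.
       */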
2211  int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
2212  {
2213  	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2214  	struct tls_prot_info *prot = &tls_ctx->prot_info;
2215  	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2216  	size_t cipher_overhead;
2217  	size_t data_len = 0;
2218  	int ret;
2219  
2220  	/* Verify that we have a full TLS header, or wait for more data */
2221  	if (strp->stm.offset + prot->prepend_size > skb->len)
2222  		return 0;
2223  
2224  	/* Sanity-check size of on-stack buffer. */
2225  	if (WARN_ON(prot->prepend_size > sizeof(header))) {
2226  		ret = -EINVAL;
2227  		goto read_failure;
2228  	}
2229  
2230  	/* Linearize header to local buffer */
2231  	ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size);
2232  	if (ret < 0)
2233  		goto read_failure;
2234  
2235  	strp->mark = header[0];
2236  
2237  	data_len = ((header[4] & 0xFF) | (header[3] << 8));
2238  
2239  	cipher_overhead = prot->tag_size;
2240  	if (prot->version != TLS_1_3_VERSION &&
2241  	    prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
2242  		cipher_overhead += prot->iv_size;
2243  
2244  	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2245  	    prot->tail_size) {
2246  		ret = -EMSGSIZE;
2247  		goto read_failure;
2248  	}
2249  	if (data_len < cipher_overhead) {
2250  		ret = -EBADMSG;
2251  		goto read_failure;
2252  	}
2253  
2254  	/* Note that both TLS 1.3 and TLS 1.2 use the TLS 1.2 version number here */
2255  	if (header[1] != TLS_1_2_VERSION_MINOR ||
2256  	    header[2] != TLS_1_2_VERSION_MAJOR) {
2257  		ret = -EINVAL;
2258  		goto read_failure;
2259  	}
2260  
2261  	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2262  				     TCP_SKB_CB(skb)->seq + strp->stm.offset);
2263  	return data_len + TLS_HEADER_SIZE;
2264  
2265  read_failure:
2266  	tls_err_abort(strp->sk, ret);
2267  
2268  	return ret;
2269  }
2270  
2271  void tls_rx_msg_ready(struct tls_strparser *strp)
2272  {
2273  	struct tls_sw_context_rx *ctx;
2274  
2275  	ctx = container_of(strp, struct tls_sw_context_rx, strp);
2276  	ctx->saved_data_ready(strp->sk);
2277  }
2278  
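      /* Replacement for sk->sk_data_ready: kick the TLS stream parser and,
       * if a BPF psock has queued ingress messages, propagate the wakeup
       * to the original data_ready callback.
       */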
2279  static void tls_data_ready(struct sock *sk)
2280  {
2281  	struct tls_context *tls_ctx = tls_get_ctx(sk);
2282  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2283  	struct sk_psock *psock;
2284  
2285  	tls_strp_data_ready(&ctx->strp);
2286  
2287  	psock = sk_psock_get(sk);
2288  	if (psock) {
2289  		if (!list_empty(&psock->ingress_msg))
2290  			ctx->saved_data_ready(sk);
2291  		sk_psock_put(sk, psock);
2292  	}
2293  }
2294  
2295  void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2296  {
2297  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2298  
2299  	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2300  	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2301  	cancel_delayed_work_sync(&ctx->tx_work.work);
2302  }
2303  
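      /* TX teardown: wait for in-flight async encryption to finish, push
       * out whatever can still be transmitted, then free the partially
       * sent record, all unsent records on tx_list, and the TX AEAD.
       */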
2304  void tls_sw_release_resources_tx(struct sock *sk)
2305  {
2306  	struct tls_context *tls_ctx = tls_get_ctx(sk);
2307  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2308  	struct tls_rec *rec, *tmp;
2309  	int pending;
2310  
2311  	/* Wait for any pending async encryptions to complete */
2312  	spin_lock_bh(&ctx->encrypt_compl_lock);
2313  	ctx->async_notify = true;
2314  	pending = atomic_read(&ctx->encrypt_pending);
2315  	spin_unlock_bh(&ctx->encrypt_compl_lock);
2316  
2317  	if (pending)
2318  		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2319  
2320  	tls_tx_records(sk, -1);
2321  
2322  	/* Free up unsent records in tx_list. First, free the partially
2323  	 * sent record, if any, at the head of tx_list.
2324  	 */
2325  	if (tls_ctx->partially_sent_record) {
2326  		tls_free_partial_record(sk, tls_ctx);
2327  		rec = list_first_entry(&ctx->tx_list,
2328  				       struct tls_rec, list);
2329  		list_del(&rec->list);
2330  		sk_msg_free(sk, &rec->msg_plaintext);
2331  		kfree(rec);
2332  	}
2333  
2334  	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2335  		list_del(&rec->list);
2336  		sk_msg_free(sk, &rec->msg_encrypted);
2337  		sk_msg_free(sk, &rec->msg_plaintext);
2338  		kfree(rec);
2339  	}
2340  
2341  	crypto_free_aead(ctx->aead_send);
2342  	tls_free_open_rec(sk);
2343  }
2344  
2345  void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2346  {
2347  	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2348  
2349  	kfree(ctx);
2350  }
2351  
2352  void tls_sw_release_resources_rx(struct sock *sk)
2353  {
2354  	struct tls_context *tls_ctx = tls_get_ctx(sk);
2355  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2356  
2357  	kfree(tls_ctx->rx.rec_seq);
2358  	kfree(tls_ctx->rx.iv);
2359  
2360  	if (ctx->aead_recv) {
2361  		__skb_queue_purge(&ctx->rx_list);
2362  		crypto_free_aead(ctx->aead_recv);
2363  		tls_strp_stop(&ctx->strp);
2364  		/* If tls_sw_strparser_arm() was not called (cleanup paths)
2365  		 * we still want to tls_strp_stop(), but sk->sk_data_ready was
2366  		 * never swapped.
2367  		 */
2368  		if (ctx->saved_data_ready) {
2369  			write_lock_bh(&sk->sk_callback_lock);
2370  			sk->sk_data_ready = ctx->saved_data_ready;
2371  			write_unlock_bh(&sk->sk_callback_lock);
2372  		}
2373  	}
2374  }
2375  
2376  void tls_sw_strparser_done(struct tls_context *tls_ctx)
2377  {
2378  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2379  
2380  	tls_strp_done(&ctx->strp);
2381  }
2382  
2383  void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2384  {
2385  	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2386  
2387  	kfree(ctx);
2388  }
2389  
2390  void tls_sw_free_resources_rx(struct sock *sk)
2391  {
2392  	struct tls_context *tls_ctx = tls_get_ctx(sk);
2393  
2394  	tls_sw_release_resources_rx(sk);
2395  	tls_sw_free_ctx_rx(tls_ctx);
2396  }
2397  
2398  /* The work handler to transmit the encrypted records in tx_list */
2399  static void tx_work_handler(struct work_struct *work)
2400  {
2401  	struct delayed_work *delayed_work = to_delayed_work(work);
2402  	struct tx_work *tx_work = container_of(delayed_work,
2403  					       struct tx_work, work);
2404  	struct sock *sk = tx_work->sk;
2405  	struct tls_context *tls_ctx = tls_get_ctx(sk);
2406  	struct tls_sw_context_tx *ctx;
2407  
2408  	if (unlikely(!tls_ctx))
2409  		return;
2410  
2411  	ctx = tls_sw_ctx_tx(tls_ctx);
2412  	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2413  		return;
2414  
2415  	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2416  		return;
2417  	mutex_lock(&tls_ctx->tx_lock);
2418  	lock_sock(sk);
2419  	tls_tx_records(sk, -1);
2420  	release_sock(sk);
2421  	mutex_unlock(&tls_ctx->tx_lock);
2422  }
2423  
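      /* True when tx_list has a first record and its tx_ready flag is set,
       * i.e. the record is fully encrypted and ready to be transmitted.
       */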
2424  static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
2425  {
2426  	struct tls_rec *rec;
2427  
2428  	rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list);
2429  	if (!rec)
2430  		return false;
2431  
2432  	return READ_ONCE(rec->tx_ready);
2433  }
2434  
2435  void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2436  {
2437  	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2438  
2439  	/* Schedule the transmission if tx list is ready */
2440  	if (tls_is_tx_ready(tx_ctx) &&
2441  	    !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2442  		schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2443  }
2444  
2445  void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2446  {
2447  	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2448  
2449  	write_lock_bh(&sk->sk_callback_lock);
2450  	rx_ctx->saved_data_ready = sk->sk_data_ready;
2451  	sk->sk_data_ready = tls_data_ready;
2452  	write_unlock_bh(&sk->sk_callback_lock);
2453  }
2454  
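      /* Zero-copy decryption requires the plaintext length to be known up
       * front.  That is always the case for TLS 1.2; for TLS 1.3 it only
       * holds when the peer is known not to pad records (rx_no_pad).
       */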
2455  void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
2456  {
2457  	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2458  
2459  	rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
2460  		tls_ctx->prot_info.version != TLS_1_3_VERSION;
2461  }
2462  
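      /* Set up software crypto state for one direction (@tx selects TX vs
       * RX): allocate the per-direction context, pull the cipher parameters
       * out of the crypto_info provided by user space, allocate and key the
       * AEAD, and for RX initialize the stream parser.
       */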
2463  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2464  {
2465  	struct tls_context *tls_ctx = tls_get_ctx(sk);
2466  	struct tls_prot_info *prot = &tls_ctx->prot_info;
2467  	struct tls_crypto_info *crypto_info;
2468  	struct tls_sw_context_tx *sw_ctx_tx = NULL;
2469  	struct tls_sw_context_rx *sw_ctx_rx = NULL;
2470  	struct cipher_context *cctx;
2471  	struct crypto_aead **aead;
2472  	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2473  	struct crypto_tfm *tfm;
2474  	char *iv, *rec_seq, *key, *salt, *cipher_name;
2475  	size_t keysize;
2476  	int rc = 0;
2477  
2478  	if (!ctx) {
2479  		rc = -EINVAL;
2480  		goto out;
2481  	}
2482  
2483  	if (tx) {
2484  		if (!ctx->priv_ctx_tx) {
2485  			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2486  			if (!sw_ctx_tx) {
2487  				rc = -ENOMEM;
2488  				goto out;
2489  			}
2490  			ctx->priv_ctx_tx = sw_ctx_tx;
2491  		} else {
2492  			sw_ctx_tx =
2493  				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2494  		}
2495  	} else {
2496  		if (!ctx->priv_ctx_rx) {
2497  			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2498  			if (!sw_ctx_rx) {
2499  				rc = -ENOMEM;
2500  				goto out;
2501  			}
2502  			ctx->priv_ctx_rx = sw_ctx_rx;
2503  		} else {
2504  			sw_ctx_rx =
2505  				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2506  		}
2507  	}
2508  
2509  	if (tx) {
2510  		crypto_init_wait(&sw_ctx_tx->async_wait);
2511  		spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2512  		crypto_info = &ctx->crypto_send.info;
2513  		cctx = &ctx->tx;
2514  		aead = &sw_ctx_tx->aead_send;
2515  		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2516  		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2517  		sw_ctx_tx->tx_work.sk = sk;
2518  	} else {
2519  		crypto_init_wait(&sw_ctx_rx->async_wait);
2520  		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2521  		init_waitqueue_head(&sw_ctx_rx->wq);
2522  		crypto_info = &ctx->crypto_recv.info;
2523  		cctx = &ctx->rx;
2524  		skb_queue_head_init(&sw_ctx_rx->rx_list);
2525  		skb_queue_head_init(&sw_ctx_rx->async_hold);
2526  		aead = &sw_ctx_rx->aead_recv;
2527  	}
2528  
2529  	switch (crypto_info->cipher_type) {
2530  	case TLS_CIPHER_AES_GCM_128: {
2531  		struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2532  
2533  		gcm_128_info = (void *)crypto_info;
2534  		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2535  		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
2536  		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2537  		iv = gcm_128_info->iv;
2538  		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
2539  		rec_seq = gcm_128_info->rec_seq;
2540  		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2541  		key = gcm_128_info->key;
2542  		salt = gcm_128_info->salt;
2543  		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2544  		cipher_name = "gcm(aes)";
2545  		break;
2546  	}
2547  	case TLS_CIPHER_AES_GCM_256: {
2548  		struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2549  
2550  		gcm_256_info = (void *)crypto_info;
2551  		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2552  		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2553  		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2554  		iv = gcm_256_info->iv;
2555  		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2556  		rec_seq = gcm_256_info->rec_seq;
2557  		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2558  		key = gcm_256_info->key;
2559  		salt = gcm_256_info->salt;
2560  		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2561  		cipher_name = "gcm(aes)";
2562  		break;
2563  	}
2564  	case TLS_CIPHER_AES_CCM_128: {
2565  		struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
2566  
2567  		ccm_128_info = (void *)crypto_info;
2568  		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2569  		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2570  		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2571  		iv = ccm_128_info->iv;
2572  		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2573  		rec_seq = ccm_128_info->rec_seq;
2574  		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2575  		key = ccm_128_info->key;
2576  		salt = ccm_128_info->salt;
2577  		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2578  		cipher_name = "ccm(aes)";
2579  		break;
2580  	}
2581  	case TLS_CIPHER_CHACHA20_POLY1305: {
2582  		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
2583  
2584  		chacha20_poly1305_info = (void *)crypto_info;
2585  		nonce_size = 0;
2586  		tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
2587  		iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
2588  		iv = chacha20_poly1305_info->iv;
2589  		rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
2590  		rec_seq = chacha20_poly1305_info->rec_seq;
2591  		keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
2592  		key = chacha20_poly1305_info->key;
2593  		salt = chacha20_poly1305_info->salt;
2594  		salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
2595  		cipher_name = "rfc7539(chacha20,poly1305)";
2596  		break;
2597  	}
2598  	case TLS_CIPHER_SM4_GCM: {
2599  		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;
2600  
2601  		sm4_gcm_info = (void *)crypto_info;
2602  		nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2603  		tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
2604  		iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2605  		iv = sm4_gcm_info->iv;
2606  		rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
2607  		rec_seq = sm4_gcm_info->rec_seq;
2608  		keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
2609  		key = sm4_gcm_info->key;
2610  		salt = sm4_gcm_info->salt;
2611  		salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
2612  		cipher_name = "gcm(sm4)";
2613  		break;
2614  	}
2615  	case TLS_CIPHER_SM4_CCM: {
2616  		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;
2617  
2618  		sm4_ccm_info = (void *)crypto_info;
2619  		nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2620  		tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
2621  		iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2622  		iv = sm4_ccm_info->iv;
2623  		rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
2624  		rec_seq = sm4_ccm_info->rec_seq;
2625  		keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
2626  		key = sm4_ccm_info->key;
2627  		salt = sm4_ccm_info->salt;
2628  		salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
2629  		cipher_name = "ccm(sm4)";
2630  		break;
2631  	}
2632  	default:
2633  		rc = -EINVAL;
2634  		goto free_priv;
2635  	}
2636  
2637  	if (crypto_info->version == TLS_1_3_VERSION) {
2638  		nonce_size = 0;
2639  		prot->aad_size = TLS_HEADER_SIZE;
2640  		prot->tail_size = 1;
2641  	} else {
2642  		prot->aad_size = TLS_AAD_SPACE_SIZE;
2643  		prot->tail_size = 0;
2644  	}
2645  
2646  	/* Sanity-check the sizes for stack allocations. */
2647  	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2648  	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
2649  	    prot->aad_size > TLS_MAX_AAD_SIZE) {
2650  		rc = -EINVAL;
2651  		goto free_priv;
2652  	}
2653  
2654  	prot->version = crypto_info->version;
2655  	prot->cipher_type = crypto_info->cipher_type;
2656  	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2657  	prot->tag_size = tag_size;
2658  	prot->overhead_size = prot->prepend_size +
2659  			      prot->tag_size + prot->tail_size;
2660  	prot->iv_size = iv_size;
2661  	prot->salt_size = salt_size;
2662  	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
2663  	if (!cctx->iv) {
2664  		rc = -ENOMEM;
2665  		goto free_priv;
2666  	}
2667  	/* Note: 128 & 256 bit salt are the same size */
2668  	prot->rec_seq_size = rec_seq_size;
2669  	memcpy(cctx->iv, salt, salt_size);
2670  	memcpy(cctx->iv + salt_size, iv, iv_size);
2671  	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
2672  	if (!cctx->rec_seq) {
2673  		rc = -ENOMEM;
2674  		goto free_iv;
2675  	}
2676  
2677  	if (!*aead) {
2678  		*aead = crypto_alloc_aead(cipher_name, 0, 0);
2679  		if (IS_ERR(*aead)) {
2680  			rc = PTR_ERR(*aead);
2681  			*aead = NULL;
2682  			goto free_rec_seq;
2683  		}
2684  	}
2685  
2686  	ctx->push_pending_record = tls_sw_push_pending_record;
2687  
2688  	rc = crypto_aead_setkey(*aead, key, keysize);
2689  
2690  	if (rc)
2691  		goto free_aead;
2692  
2693  	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2694  	if (rc)
2695  		goto free_aead;
2696  
2697  	if (sw_ctx_rx) {
2698  		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2699  
2700  		tls_update_rx_zc_capable(ctx);
2701  		sw_ctx_rx->async_capable =
2702  			crypto_info->version != TLS_1_3_VERSION &&
2703  			!!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
2704  
2705  		rc = tls_strp_init(&sw_ctx_rx->strp, sk);
2706  		if (rc)
2707  			goto free_aead;
2708  	}
2709  
2710  	goto out;
2711  
2712  free_aead:
2713  	crypto_free_aead(*aead);
2714  	*aead = NULL;
2715  free_rec_seq:
2716  	kfree(cctx->rec_seq);
2717  	cctx->rec_seq = NULL;
2718  free_iv:
2719  	kfree(cctx->iv);
2720  	cctx->iv = NULL;
2721  free_priv:
2722  	if (tx) {
2723  		kfree(ctx->priv_ctx_tx);
2724  		ctx->priv_ctx_tx = NULL;
2725  	} else {
2726  		kfree(ctx->priv_ctx_rx);
2727  		ctx->priv_ctx_rx = NULL;
2728  	}
2729  out:
2730  	return rc;
2731  }
2732