xref: /openbmc/linux/net/tls/tls_sw.c (revision 68198dca)
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36 
37 #include <linux/module.h>
38 #include <crypto/aead.h>
39 
40 #include <net/tls.h>
41 
42 static void trim_sg(struct sock *sk, struct scatterlist *sg,
43 		    int *sg_num_elem, unsigned int *sg_size, int target_size)
44 {
45 	int i = *sg_num_elem - 1;
46 	int trim = *sg_size - target_size;
47 
48 	if (trim <= 0) {
49 		WARN_ON(trim < 0);
50 		return;
51 	}
52 
53 	*sg_size = target_size;
54 	while (trim >= sg[i].length) {
55 		trim -= sg[i].length;
56 		sk_mem_uncharge(sk, sg[i].length);
57 		put_page(sg_page(&sg[i]));
58 		i--;
59 
60 		if (i < 0)
61 			goto out;
62 	}
63 
64 	sg[i].length -= trim;
65 	sk_mem_uncharge(sk, trim);
66 
67 out:
68 	*sg_num_elem = i + 1;
69 }
70 
71 static void trim_both_sgl(struct sock *sk, int target_size)
72 {
73 	struct tls_context *tls_ctx = tls_get_ctx(sk);
74 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
75 
76 	trim_sg(sk, ctx->sg_plaintext_data,
77 		&ctx->sg_plaintext_num_elem,
78 		&ctx->sg_plaintext_size,
79 		target_size);
80 
81 	if (target_size > 0)
82 		target_size += tls_ctx->overhead_size;
83 
84 	trim_sg(sk, ctx->sg_encrypted_data,
85 		&ctx->sg_encrypted_num_elem,
86 		&ctx->sg_encrypted_size,
87 		target_size);
88 }
89 
90 static int alloc_sg(struct sock *sk, int len, struct scatterlist *sg,
91 		    int *sg_num_elem, unsigned int *sg_size,
92 		    int first_coalesce)
93 {
94 	struct page_frag *pfrag;
95 	unsigned int size = *sg_size;
96 	int num_elem = *sg_num_elem, use = 0, rc = 0;
97 	struct scatterlist *sge;
98 	unsigned int orig_offset;
99 
100 	len -= size;
101 	pfrag = sk_page_frag(sk);
102 
103 	while (len > 0) {
104 		if (!sk_page_frag_refill(sk, pfrag)) {
105 			rc = -ENOMEM;
106 			goto out;
107 		}
108 
109 		use = min_t(int, len, pfrag->size - pfrag->offset);
110 
111 		if (!sk_wmem_schedule(sk, use)) {
112 			rc = -ENOMEM;
113 			goto out;
114 		}
115 
116 		sk_mem_charge(sk, use);
117 		size += use;
118 		orig_offset = pfrag->offset;
119 		pfrag->offset += use;
120 
121 		sge = sg + num_elem - 1;
122 		if (num_elem > first_coalesce && sg_page(sg) == pfrag->page &&
123 		    sg->offset + sg->length == orig_offset) {
124 			sg->length += use;
125 		} else {
126 			sge++;
127 			sg_unmark_end(sge);
128 			sg_set_page(sge, pfrag->page, use, orig_offset);
129 			get_page(pfrag->page);
130 			++num_elem;
131 			if (num_elem == MAX_SKB_FRAGS) {
132 				rc = -ENOSPC;
133 				break;
134 			}
135 		}
136 
137 		len -= use;
138 	}
139 	goto out;
140 
141 out:
142 	*sg_size = size;
143 	*sg_num_elem = num_elem;
144 	return rc;
145 }
146 
147 static int alloc_encrypted_sg(struct sock *sk, int len)
148 {
149 	struct tls_context *tls_ctx = tls_get_ctx(sk);
150 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
151 	int rc = 0;
152 
153 	rc = alloc_sg(sk, len, ctx->sg_encrypted_data,
154 		      &ctx->sg_encrypted_num_elem, &ctx->sg_encrypted_size, 0);
155 
156 	return rc;
157 }
158 
159 static int alloc_plaintext_sg(struct sock *sk, int len)
160 {
161 	struct tls_context *tls_ctx = tls_get_ctx(sk);
162 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
163 	int rc = 0;
164 
165 	rc = alloc_sg(sk, len, ctx->sg_plaintext_data,
166 		      &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
167 		      tls_ctx->pending_open_record_frags);
168 
169 	return rc;
170 }
171 
172 static void free_sg(struct sock *sk, struct scatterlist *sg,
173 		    int *sg_num_elem, unsigned int *sg_size)
174 {
175 	int i, n = *sg_num_elem;
176 
177 	for (i = 0; i < n; ++i) {
178 		sk_mem_uncharge(sk, sg[i].length);
179 		put_page(sg_page(&sg[i]));
180 	}
181 	*sg_num_elem = 0;
182 	*sg_size = 0;
183 }
184 
185 static void tls_free_both_sg(struct sock *sk)
186 {
187 	struct tls_context *tls_ctx = tls_get_ctx(sk);
188 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
189 
190 	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
191 		&ctx->sg_encrypted_size);
192 
193 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
194 		&ctx->sg_plaintext_size);
195 }
196 
197 static int tls_do_encryption(struct tls_context *tls_ctx,
198 			     struct tls_sw_context *ctx, size_t data_len,
199 			     gfp_t flags)
200 {
201 	unsigned int req_size = sizeof(struct aead_request) +
202 		crypto_aead_reqsize(ctx->aead_send);
203 	struct aead_request *aead_req;
204 	int rc;
205 
206 	aead_req = kzalloc(req_size, flags);
207 	if (!aead_req)
208 		return -ENOMEM;
209 
210 	ctx->sg_encrypted_data[0].offset += tls_ctx->prepend_size;
211 	ctx->sg_encrypted_data[0].length -= tls_ctx->prepend_size;
212 
213 	aead_request_set_tfm(aead_req, ctx->aead_send);
214 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
215 	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
216 			       data_len, tls_ctx->iv);
217 	rc = crypto_aead_encrypt(aead_req);
218 
219 	ctx->sg_encrypted_data[0].offset -= tls_ctx->prepend_size;
220 	ctx->sg_encrypted_data[0].length += tls_ctx->prepend_size;
221 
222 	kfree(aead_req);
223 	return rc;
224 }
225 
226 static int tls_push_record(struct sock *sk, int flags,
227 			   unsigned char record_type)
228 {
229 	struct tls_context *tls_ctx = tls_get_ctx(sk);
230 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
231 	int rc;
232 
233 	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
234 	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
235 
236 	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
237 		     tls_ctx->rec_seq, tls_ctx->rec_seq_size,
238 		     record_type);
239 
240 	tls_fill_prepend(tls_ctx,
241 			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
242 			 ctx->sg_encrypted_data[0].offset,
243 			 ctx->sg_plaintext_size, record_type);
244 
245 	tls_ctx->pending_open_record_frags = 0;
246 	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
247 
248 	rc = tls_do_encryption(tls_ctx, ctx, ctx->sg_plaintext_size,
249 			       sk->sk_allocation);
250 	if (rc < 0) {
251 		/* If we are called from write_space and
252 		 * we fail, we need to set this SOCK_NOSPACE
253 		 * to trigger another write_space in the future.
254 		 */
255 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
256 		return rc;
257 	}
258 
259 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
260 		&ctx->sg_plaintext_size);
261 
262 	ctx->sg_encrypted_num_elem = 0;
263 	ctx->sg_encrypted_size = 0;
264 
265 	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
266 	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
267 	if (rc < 0 && rc != -EAGAIN)
268 		tls_err_abort(sk);
269 
270 	tls_advance_record_sn(sk, tls_ctx);
271 	return rc;
272 }
273 
274 static int tls_sw_push_pending_record(struct sock *sk, int flags)
275 {
276 	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
277 }
278 
279 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
280 			      int length)
281 {
282 	struct tls_context *tls_ctx = tls_get_ctx(sk);
283 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
284 	struct page *pages[MAX_SKB_FRAGS];
285 
286 	size_t offset;
287 	ssize_t copied, use;
288 	int i = 0;
289 	unsigned int size = ctx->sg_plaintext_size;
290 	int num_elem = ctx->sg_plaintext_num_elem;
291 	int rc = 0;
292 	int maxpages;
293 
294 	while (length > 0) {
295 		i = 0;
296 		maxpages = ARRAY_SIZE(ctx->sg_plaintext_data) - num_elem;
297 		if (maxpages == 0) {
298 			rc = -EFAULT;
299 			goto out;
300 		}
301 		copied = iov_iter_get_pages(from, pages,
302 					    length,
303 					    maxpages, &offset);
304 		if (copied <= 0) {
305 			rc = -EFAULT;
306 			goto out;
307 		}
308 
309 		iov_iter_advance(from, copied);
310 
311 		length -= copied;
312 		size += copied;
313 		while (copied) {
314 			use = min_t(int, copied, PAGE_SIZE - offset);
315 
316 			sg_set_page(&ctx->sg_plaintext_data[num_elem],
317 				    pages[i], use, offset);
318 			sg_unmark_end(&ctx->sg_plaintext_data[num_elem]);
319 			sk_mem_charge(sk, use);
320 
321 			offset = 0;
322 			copied -= use;
323 
324 			++i;
325 			++num_elem;
326 		}
327 	}
328 
329 out:
330 	ctx->sg_plaintext_size = size;
331 	ctx->sg_plaintext_num_elem = num_elem;
332 	return rc;
333 }
334 
335 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
336 			     int bytes)
337 {
338 	struct tls_context *tls_ctx = tls_get_ctx(sk);
339 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
340 	struct scatterlist *sg = ctx->sg_plaintext_data;
341 	int copy, i, rc = 0;
342 
343 	for (i = tls_ctx->pending_open_record_frags;
344 	     i < ctx->sg_plaintext_num_elem; ++i) {
345 		copy = sg[i].length;
346 		if (copy_from_iter(
347 				page_address(sg_page(&sg[i])) + sg[i].offset,
348 				copy, from) != copy) {
349 			rc = -EFAULT;
350 			goto out;
351 		}
352 		bytes -= copy;
353 
354 		++tls_ctx->pending_open_record_frags;
355 
356 		if (!bytes)
357 			break;
358 	}
359 
360 out:
361 	return rc;
362 }
363 
364 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
365 {
366 	struct tls_context *tls_ctx = tls_get_ctx(sk);
367 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
368 	int ret = 0;
369 	int required_size;
370 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
371 	bool eor = !(msg->msg_flags & MSG_MORE);
372 	size_t try_to_copy, copied = 0;
373 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
374 	int record_room;
375 	bool full_record;
376 	int orig_size;
377 
378 	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
379 		return -ENOTSUPP;
380 
381 	lock_sock(sk);
382 
383 	if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
384 		goto send_end;
385 
386 	if (unlikely(msg->msg_controllen)) {
387 		ret = tls_proccess_cmsg(sk, msg, &record_type);
388 		if (ret)
389 			goto send_end;
390 	}
391 
392 	while (msg_data_left(msg)) {
393 		if (sk->sk_err) {
394 			ret = sk->sk_err;
395 			goto send_end;
396 		}
397 
398 		orig_size = ctx->sg_plaintext_size;
399 		full_record = false;
400 		try_to_copy = msg_data_left(msg);
401 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
402 		if (try_to_copy >= record_room) {
403 			try_to_copy = record_room;
404 			full_record = true;
405 		}
406 
407 		required_size = ctx->sg_plaintext_size + try_to_copy +
408 				tls_ctx->overhead_size;
409 
410 		if (!sk_stream_memory_free(sk))
411 			goto wait_for_sndbuf;
412 alloc_encrypted:
413 		ret = alloc_encrypted_sg(sk, required_size);
414 		if (ret) {
415 			if (ret != -ENOSPC)
416 				goto wait_for_memory;
417 
418 			/* Adjust try_to_copy according to the amount that was
419 			 * actually allocated. The difference is due
420 			 * to max sg elements limit
421 			 */
422 			try_to_copy -= required_size - ctx->sg_encrypted_size;
423 			full_record = true;
424 		}
425 
426 		if (full_record || eor) {
427 			ret = zerocopy_from_iter(sk, &msg->msg_iter,
428 						 try_to_copy);
429 			if (ret)
430 				goto fallback_to_reg_send;
431 
432 			copied += try_to_copy;
433 			ret = tls_push_record(sk, msg->msg_flags, record_type);
434 			if (!ret)
435 				continue;
436 			if (ret == -EAGAIN)
437 				goto send_end;
438 
439 			copied -= try_to_copy;
440 fallback_to_reg_send:
441 			iov_iter_revert(&msg->msg_iter,
442 					ctx->sg_plaintext_size - orig_size);
443 			trim_sg(sk, ctx->sg_plaintext_data,
444 				&ctx->sg_plaintext_num_elem,
445 				&ctx->sg_plaintext_size,
446 				orig_size);
447 		}
448 
449 		required_size = ctx->sg_plaintext_size + try_to_copy;
450 alloc_plaintext:
451 		ret = alloc_plaintext_sg(sk, required_size);
452 		if (ret) {
453 			if (ret != -ENOSPC)
454 				goto wait_for_memory;
455 
456 			/* Adjust try_to_copy according to the amount that was
457 			 * actually allocated. The difference is due
458 			 * to max sg elements limit
459 			 */
460 			try_to_copy -= required_size - ctx->sg_plaintext_size;
461 			full_record = true;
462 
463 			trim_sg(sk, ctx->sg_encrypted_data,
464 				&ctx->sg_encrypted_num_elem,
465 				&ctx->sg_encrypted_size,
466 				ctx->sg_plaintext_size +
467 				tls_ctx->overhead_size);
468 		}
469 
470 		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
471 		if (ret)
472 			goto trim_sgl;
473 
474 		copied += try_to_copy;
475 		if (full_record || eor) {
476 push_record:
477 			ret = tls_push_record(sk, msg->msg_flags, record_type);
478 			if (ret) {
479 				if (ret == -ENOMEM)
480 					goto wait_for_memory;
481 
482 				goto send_end;
483 			}
484 		}
485 
486 		continue;
487 
488 wait_for_sndbuf:
489 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
490 wait_for_memory:
491 		ret = sk_stream_wait_memory(sk, &timeo);
492 		if (ret) {
493 trim_sgl:
494 			trim_both_sgl(sk, orig_size);
495 			goto send_end;
496 		}
497 
498 		if (tls_is_pending_closed_record(tls_ctx))
499 			goto push_record;
500 
501 		if (ctx->sg_encrypted_size < required_size)
502 			goto alloc_encrypted;
503 
504 		goto alloc_plaintext;
505 	}
506 
507 send_end:
508 	ret = sk_stream_error(sk, msg->msg_flags, ret);
509 
510 	release_sock(sk);
511 	return copied ? copied : ret;
512 }
513 
514 int tls_sw_sendpage(struct sock *sk, struct page *page,
515 		    int offset, size_t size, int flags)
516 {
517 	struct tls_context *tls_ctx = tls_get_ctx(sk);
518 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
519 	int ret = 0;
520 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
521 	bool eor;
522 	size_t orig_size = size;
523 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
524 	struct scatterlist *sg;
525 	bool full_record;
526 	int record_room;
527 
528 	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
529 		      MSG_SENDPAGE_NOTLAST))
530 		return -ENOTSUPP;
531 
532 	/* No MSG_EOR from splice, only look at MSG_MORE */
533 	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
534 
535 	lock_sock(sk);
536 
537 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
538 
539 	if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
540 		goto sendpage_end;
541 
542 	/* Call the sk_stream functions to manage the sndbuf mem. */
543 	while (size > 0) {
544 		size_t copy, required_size;
545 
546 		if (sk->sk_err) {
547 			ret = sk->sk_err;
548 			goto sendpage_end;
549 		}
550 
551 		full_record = false;
552 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
553 		copy = size;
554 		if (copy >= record_room) {
555 			copy = record_room;
556 			full_record = true;
557 		}
558 		required_size = ctx->sg_plaintext_size + copy +
559 			      tls_ctx->overhead_size;
560 
561 		if (!sk_stream_memory_free(sk))
562 			goto wait_for_sndbuf;
563 alloc_payload:
564 		ret = alloc_encrypted_sg(sk, required_size);
565 		if (ret) {
566 			if (ret != -ENOSPC)
567 				goto wait_for_memory;
568 
569 			/* Adjust copy according to the amount that was
570 			 * actually allocated. The difference is due
571 			 * to max sg elements limit
572 			 */
573 			copy -= required_size - ctx->sg_plaintext_size;
574 			full_record = true;
575 		}
576 
577 		get_page(page);
578 		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
579 		sg_set_page(sg, page, copy, offset);
580 		ctx->sg_plaintext_num_elem++;
581 
582 		sk_mem_charge(sk, copy);
583 		offset += copy;
584 		size -= copy;
585 		ctx->sg_plaintext_size += copy;
586 		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
587 
588 		if (full_record || eor ||
589 		    ctx->sg_plaintext_num_elem ==
590 		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
591 push_record:
592 			ret = tls_push_record(sk, flags, record_type);
593 			if (ret) {
594 				if (ret == -ENOMEM)
595 					goto wait_for_memory;
596 
597 				goto sendpage_end;
598 			}
599 		}
600 		continue;
601 wait_for_sndbuf:
602 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
603 wait_for_memory:
604 		ret = sk_stream_wait_memory(sk, &timeo);
605 		if (ret) {
606 			trim_both_sgl(sk, ctx->sg_plaintext_size);
607 			goto sendpage_end;
608 		}
609 
610 		if (tls_is_pending_closed_record(tls_ctx))
611 			goto push_record;
612 
613 		goto alloc_payload;
614 	}
615 
616 sendpage_end:
617 	if (orig_size > size)
618 		ret = orig_size - size;
619 	else
620 		ret = sk_stream_error(sk, flags, ret);
621 
622 	release_sock(sk);
623 	return ret;
624 }
625 
626 void tls_sw_free_tx_resources(struct sock *sk)
627 {
628 	struct tls_context *tls_ctx = tls_get_ctx(sk);
629 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
630 
631 	if (ctx->aead_send)
632 		crypto_free_aead(ctx->aead_send);
633 
634 	tls_free_both_sg(sk);
635 
636 	kfree(ctx);
637 	kfree(tls_ctx);
638 }
639 
640 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
641 {
642 	char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
643 	struct tls_crypto_info *crypto_info;
644 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
645 	struct tls_sw_context *sw_ctx;
646 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
647 	char *iv, *rec_seq;
648 	int rc = 0;
649 
650 	if (!ctx) {
651 		rc = -EINVAL;
652 		goto out;
653 	}
654 
655 	if (ctx->priv_ctx) {
656 		rc = -EEXIST;
657 		goto out;
658 	}
659 
660 	sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
661 	if (!sw_ctx) {
662 		rc = -ENOMEM;
663 		goto out;
664 	}
665 
666 	ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
667 
668 	crypto_info = &ctx->crypto_send;
669 	switch (crypto_info->cipher_type) {
670 	case TLS_CIPHER_AES_GCM_128: {
671 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
672 		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
673 		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
674 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
675 		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
676 		rec_seq =
677 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
678 		gcm_128_info =
679 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
680 		break;
681 	}
682 	default:
683 		rc = -EINVAL;
684 		goto out;
685 	}
686 
687 	ctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
688 	ctx->tag_size = tag_size;
689 	ctx->overhead_size = ctx->prepend_size + ctx->tag_size;
690 	ctx->iv_size = iv_size;
691 	ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
692 			  GFP_KERNEL);
693 	if (!ctx->iv) {
694 		rc = -ENOMEM;
695 		goto out;
696 	}
697 	memcpy(ctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
698 	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
699 	ctx->rec_seq_size = rec_seq_size;
700 	ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
701 	if (!ctx->rec_seq) {
702 		rc = -ENOMEM;
703 		goto free_iv;
704 	}
705 	memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
706 
707 	sg_init_table(sw_ctx->sg_encrypted_data,
708 		      ARRAY_SIZE(sw_ctx->sg_encrypted_data));
709 	sg_init_table(sw_ctx->sg_plaintext_data,
710 		      ARRAY_SIZE(sw_ctx->sg_plaintext_data));
711 
712 	sg_init_table(sw_ctx->sg_aead_in, 2);
713 	sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
714 		   sizeof(sw_ctx->aad_space));
715 	sg_unmark_end(&sw_ctx->sg_aead_in[1]);
716 	sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
717 	sg_init_table(sw_ctx->sg_aead_out, 2);
718 	sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
719 		   sizeof(sw_ctx->aad_space));
720 	sg_unmark_end(&sw_ctx->sg_aead_out[1]);
721 	sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
722 
723 	if (!sw_ctx->aead_send) {
724 		sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0);
725 		if (IS_ERR(sw_ctx->aead_send)) {
726 			rc = PTR_ERR(sw_ctx->aead_send);
727 			sw_ctx->aead_send = NULL;
728 			goto free_rec_seq;
729 		}
730 	}
731 
732 	ctx->push_pending_record = tls_sw_push_pending_record;
733 
734 	memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
735 
736 	rc = crypto_aead_setkey(sw_ctx->aead_send, keyval,
737 				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
738 	if (rc)
739 		goto free_aead;
740 
741 	rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tag_size);
742 	if (!rc)
743 		goto out;
744 
745 free_aead:
746 	crypto_free_aead(sw_ctx->aead_send);
747 	sw_ctx->aead_send = NULL;
748 free_rec_seq:
749 	kfree(ctx->rec_seq);
750 	ctx->rec_seq = NULL;
751 free_iv:
752 	kfree(ctx->iv);
753 	ctx->iv = NULL;
754 out:
755 	return rc;
756 }
757