xref: /openbmc/linux/net/tls/tls_sw.c (revision 981ab3f1)
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36 
37 #include <linux/module.h>
38 #include <crypto/aead.h>
39 
40 #include <net/tls.h>
41 
42 static inline void tls_make_aad(int recv,
43 				char *buf,
44 				size_t size,
45 				char *record_sequence,
46 				int record_sequence_size,
47 				unsigned char record_type)
48 {
49 	memcpy(buf, record_sequence, record_sequence_size);
50 
51 	buf[8] = record_type;
52 	buf[9] = TLS_1_2_VERSION_MAJOR;
53 	buf[10] = TLS_1_2_VERSION_MINOR;
54 	buf[11] = size >> 8;
55 	buf[12] = size & 0xFF;
56 }
57 
58 static void trim_sg(struct sock *sk, struct scatterlist *sg,
59 		    int *sg_num_elem, unsigned int *sg_size, int target_size)
60 {
61 	int i = *sg_num_elem - 1;
62 	int trim = *sg_size - target_size;
63 
64 	if (trim <= 0) {
65 		WARN_ON(trim < 0);
66 		return;
67 	}
68 
69 	*sg_size = target_size;
70 	while (trim >= sg[i].length) {
71 		trim -= sg[i].length;
72 		sk_mem_uncharge(sk, sg[i].length);
73 		put_page(sg_page(&sg[i]));
74 		i--;
75 
76 		if (i < 0)
77 			goto out;
78 	}
79 
80 	sg[i].length -= trim;
81 	sk_mem_uncharge(sk, trim);
82 
83 out:
84 	*sg_num_elem = i + 1;
85 }
86 
87 static void trim_both_sgl(struct sock *sk, int target_size)
88 {
89 	struct tls_context *tls_ctx = tls_get_ctx(sk);
90 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
91 
92 	trim_sg(sk, ctx->sg_plaintext_data,
93 		&ctx->sg_plaintext_num_elem,
94 		&ctx->sg_plaintext_size,
95 		target_size);
96 
97 	if (target_size > 0)
98 		target_size += tls_ctx->overhead_size;
99 
100 	trim_sg(sk, ctx->sg_encrypted_data,
101 		&ctx->sg_encrypted_num_elem,
102 		&ctx->sg_encrypted_size,
103 		target_size);
104 }
105 
106 static int alloc_sg(struct sock *sk, int len, struct scatterlist *sg,
107 		    int *sg_num_elem, unsigned int *sg_size,
108 		    int first_coalesce)
109 {
110 	struct page_frag *pfrag;
111 	unsigned int size = *sg_size;
112 	int num_elem = *sg_num_elem, use = 0, rc = 0;
113 	struct scatterlist *sge;
114 	unsigned int orig_offset;
115 
116 	len -= size;
117 	pfrag = sk_page_frag(sk);
118 
119 	while (len > 0) {
120 		if (!sk_page_frag_refill(sk, pfrag)) {
121 			rc = -ENOMEM;
122 			goto out;
123 		}
124 
125 		use = min_t(int, len, pfrag->size - pfrag->offset);
126 
127 		if (!sk_wmem_schedule(sk, use)) {
128 			rc = -ENOMEM;
129 			goto out;
130 		}
131 
132 		sk_mem_charge(sk, use);
133 		size += use;
134 		orig_offset = pfrag->offset;
135 		pfrag->offset += use;
136 
137 		sge = sg + num_elem - 1;
138 		if (num_elem > first_coalesce && sg_page(sg) == pfrag->page &&
139 		    sg->offset + sg->length == orig_offset) {
140 			sg->length += use;
141 		} else {
142 			sge++;
143 			sg_unmark_end(sge);
144 			sg_set_page(sge, pfrag->page, use, orig_offset);
145 			get_page(pfrag->page);
146 			++num_elem;
147 			if (num_elem == MAX_SKB_FRAGS) {
148 				rc = -ENOSPC;
149 				break;
150 			}
151 		}
152 
153 		len -= use;
154 	}
155 	goto out;
156 
157 out:
158 	*sg_size = size;
159 	*sg_num_elem = num_elem;
160 	return rc;
161 }
162 
163 static int alloc_encrypted_sg(struct sock *sk, int len)
164 {
165 	struct tls_context *tls_ctx = tls_get_ctx(sk);
166 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
167 	int rc = 0;
168 
169 	rc = alloc_sg(sk, len, ctx->sg_encrypted_data,
170 		      &ctx->sg_encrypted_num_elem, &ctx->sg_encrypted_size, 0);
171 
172 	return rc;
173 }
174 
175 static int alloc_plaintext_sg(struct sock *sk, int len)
176 {
177 	struct tls_context *tls_ctx = tls_get_ctx(sk);
178 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
179 	int rc = 0;
180 
181 	rc = alloc_sg(sk, len, ctx->sg_plaintext_data,
182 		      &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
183 		      tls_ctx->pending_open_record_frags);
184 
185 	return rc;
186 }
187 
188 static void free_sg(struct sock *sk, struct scatterlist *sg,
189 		    int *sg_num_elem, unsigned int *sg_size)
190 {
191 	int i, n = *sg_num_elem;
192 
193 	for (i = 0; i < n; ++i) {
194 		sk_mem_uncharge(sk, sg[i].length);
195 		put_page(sg_page(&sg[i]));
196 	}
197 	*sg_num_elem = 0;
198 	*sg_size = 0;
199 }
200 
201 static void tls_free_both_sg(struct sock *sk)
202 {
203 	struct tls_context *tls_ctx = tls_get_ctx(sk);
204 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
205 
206 	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
207 		&ctx->sg_encrypted_size);
208 
209 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
210 		&ctx->sg_plaintext_size);
211 }
212 
213 static int tls_do_encryption(struct tls_context *tls_ctx,
214 			     struct tls_sw_context *ctx, size_t data_len,
215 			     gfp_t flags)
216 {
217 	unsigned int req_size = sizeof(struct aead_request) +
218 		crypto_aead_reqsize(ctx->aead_send);
219 	struct aead_request *aead_req;
220 	int rc;
221 
222 	aead_req = kmalloc(req_size, flags);
223 	if (!aead_req)
224 		return -ENOMEM;
225 
226 	ctx->sg_encrypted_data[0].offset += tls_ctx->prepend_size;
227 	ctx->sg_encrypted_data[0].length -= tls_ctx->prepend_size;
228 
229 	aead_request_set_tfm(aead_req, ctx->aead_send);
230 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
231 	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
232 			       data_len, tls_ctx->iv);
233 	rc = crypto_aead_encrypt(aead_req);
234 
235 	ctx->sg_encrypted_data[0].offset -= tls_ctx->prepend_size;
236 	ctx->sg_encrypted_data[0].length += tls_ctx->prepend_size;
237 
238 	kfree(aead_req);
239 	return rc;
240 }
241 
242 static int tls_push_record(struct sock *sk, int flags,
243 			   unsigned char record_type)
244 {
245 	struct tls_context *tls_ctx = tls_get_ctx(sk);
246 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
247 	int rc;
248 
249 	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
250 	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
251 
252 	tls_make_aad(0, ctx->aad_space, ctx->sg_plaintext_size,
253 		     tls_ctx->rec_seq, tls_ctx->rec_seq_size,
254 		     record_type);
255 
256 	tls_fill_prepend(tls_ctx,
257 			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
258 			 ctx->sg_encrypted_data[0].offset,
259 			 ctx->sg_plaintext_size, record_type);
260 
261 	tls_ctx->pending_open_record_frags = 0;
262 	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
263 
264 	rc = tls_do_encryption(tls_ctx, ctx, ctx->sg_plaintext_size,
265 			       sk->sk_allocation);
266 	if (rc < 0) {
267 		/* If we are called from write_space and
268 		 * we fail, we need to set this SOCK_NOSPACE
269 		 * to trigger another write_space in the future.
270 		 */
271 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
272 		return rc;
273 	}
274 
275 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
276 		&ctx->sg_plaintext_size);
277 
278 	ctx->sg_encrypted_num_elem = 0;
279 	ctx->sg_encrypted_size = 0;
280 
281 	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
282 	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
283 	if (rc < 0 && rc != -EAGAIN)
284 		tls_err_abort(sk);
285 
286 	tls_advance_record_sn(sk, tls_ctx);
287 	return rc;
288 }
289 
290 static int tls_sw_push_pending_record(struct sock *sk, int flags)
291 {
292 	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
293 }
294 
295 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
296 			      int length)
297 {
298 	struct tls_context *tls_ctx = tls_get_ctx(sk);
299 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
300 	struct page *pages[MAX_SKB_FRAGS];
301 
302 	size_t offset;
303 	ssize_t copied, use;
304 	int i = 0;
305 	unsigned int size = ctx->sg_plaintext_size;
306 	int num_elem = ctx->sg_plaintext_num_elem;
307 	int rc = 0;
308 	int maxpages;
309 
310 	while (length > 0) {
311 		i = 0;
312 		maxpages = ARRAY_SIZE(ctx->sg_plaintext_data) - num_elem;
313 		if (maxpages == 0) {
314 			rc = -EFAULT;
315 			goto out;
316 		}
317 		copied = iov_iter_get_pages(from, pages,
318 					    length,
319 					    maxpages, &offset);
320 		if (copied <= 0) {
321 			rc = -EFAULT;
322 			goto out;
323 		}
324 
325 		iov_iter_advance(from, copied);
326 
327 		length -= copied;
328 		size += copied;
329 		while (copied) {
330 			use = min_t(int, copied, PAGE_SIZE - offset);
331 
332 			sg_set_page(&ctx->sg_plaintext_data[num_elem],
333 				    pages[i], use, offset);
334 			sg_unmark_end(&ctx->sg_plaintext_data[num_elem]);
335 			sk_mem_charge(sk, use);
336 
337 			offset = 0;
338 			copied -= use;
339 
340 			++i;
341 			++num_elem;
342 		}
343 	}
344 
345 out:
346 	ctx->sg_plaintext_size = size;
347 	ctx->sg_plaintext_num_elem = num_elem;
348 	return rc;
349 }
350 
351 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
352 			     int bytes)
353 {
354 	struct tls_context *tls_ctx = tls_get_ctx(sk);
355 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
356 	struct scatterlist *sg = ctx->sg_plaintext_data;
357 	int copy, i, rc = 0;
358 
359 	for (i = tls_ctx->pending_open_record_frags;
360 	     i < ctx->sg_plaintext_num_elem; ++i) {
361 		copy = sg[i].length;
362 		if (copy_from_iter(
363 				page_address(sg_page(&sg[i])) + sg[i].offset,
364 				copy, from) != copy) {
365 			rc = -EFAULT;
366 			goto out;
367 		}
368 		bytes -= copy;
369 
370 		++tls_ctx->pending_open_record_frags;
371 
372 		if (!bytes)
373 			break;
374 	}
375 
376 out:
377 	return rc;
378 }
379 
380 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
381 {
382 	struct tls_context *tls_ctx = tls_get_ctx(sk);
383 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
384 	int ret = 0;
385 	int required_size;
386 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
387 	bool eor = !(msg->msg_flags & MSG_MORE);
388 	size_t try_to_copy, copied = 0;
389 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
390 	int record_room;
391 	bool full_record;
392 	int orig_size;
393 
394 	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
395 		return -ENOTSUPP;
396 
397 	lock_sock(sk);
398 
399 	if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
400 		goto send_end;
401 
402 	if (unlikely(msg->msg_controllen)) {
403 		ret = tls_proccess_cmsg(sk, msg, &record_type);
404 		if (ret)
405 			goto send_end;
406 	}
407 
408 	while (msg_data_left(msg)) {
409 		if (sk->sk_err) {
410 			ret = sk->sk_err;
411 			goto send_end;
412 		}
413 
414 		orig_size = ctx->sg_plaintext_size;
415 		full_record = false;
416 		try_to_copy = msg_data_left(msg);
417 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
418 		if (try_to_copy >= record_room) {
419 			try_to_copy = record_room;
420 			full_record = true;
421 		}
422 
423 		required_size = ctx->sg_plaintext_size + try_to_copy +
424 				tls_ctx->overhead_size;
425 
426 		if (!sk_stream_memory_free(sk))
427 			goto wait_for_sndbuf;
428 alloc_encrypted:
429 		ret = alloc_encrypted_sg(sk, required_size);
430 		if (ret) {
431 			if (ret != -ENOSPC)
432 				goto wait_for_memory;
433 
434 			/* Adjust try_to_copy according to the amount that was
435 			 * actually allocated. The difference is due
436 			 * to max sg elements limit
437 			 */
438 			try_to_copy -= required_size - ctx->sg_encrypted_size;
439 			full_record = true;
440 		}
441 
442 		if (full_record || eor) {
443 			ret = zerocopy_from_iter(sk, &msg->msg_iter,
444 						 try_to_copy);
445 			if (ret)
446 				goto fallback_to_reg_send;
447 
448 			copied += try_to_copy;
449 			ret = tls_push_record(sk, msg->msg_flags, record_type);
450 			if (!ret)
451 				continue;
452 			if (ret == -EAGAIN)
453 				goto send_end;
454 
455 			copied -= try_to_copy;
456 fallback_to_reg_send:
457 			iov_iter_revert(&msg->msg_iter,
458 					ctx->sg_plaintext_size - orig_size);
459 			trim_sg(sk, ctx->sg_plaintext_data,
460 				&ctx->sg_plaintext_num_elem,
461 				&ctx->sg_plaintext_size,
462 				orig_size);
463 		}
464 
465 		required_size = ctx->sg_plaintext_size + try_to_copy;
466 alloc_plaintext:
467 		ret = alloc_plaintext_sg(sk, required_size);
468 		if (ret) {
469 			if (ret != -ENOSPC)
470 				goto wait_for_memory;
471 
472 			/* Adjust try_to_copy according to the amount that was
473 			 * actually allocated. The difference is due
474 			 * to max sg elements limit
475 			 */
476 			try_to_copy -= required_size - ctx->sg_plaintext_size;
477 			full_record = true;
478 
479 			trim_sg(sk, ctx->sg_encrypted_data,
480 				&ctx->sg_encrypted_num_elem,
481 				&ctx->sg_encrypted_size,
482 				ctx->sg_plaintext_size +
483 				tls_ctx->overhead_size);
484 		}
485 
486 		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
487 		if (ret)
488 			goto trim_sgl;
489 
490 		copied += try_to_copy;
491 		if (full_record || eor) {
492 push_record:
493 			ret = tls_push_record(sk, msg->msg_flags, record_type);
494 			if (ret) {
495 				if (ret == -ENOMEM)
496 					goto wait_for_memory;
497 
498 				goto send_end;
499 			}
500 		}
501 
502 		continue;
503 
504 wait_for_sndbuf:
505 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
506 wait_for_memory:
507 		ret = sk_stream_wait_memory(sk, &timeo);
508 		if (ret) {
509 trim_sgl:
510 			trim_both_sgl(sk, orig_size);
511 			goto send_end;
512 		}
513 
514 		if (tls_is_pending_closed_record(tls_ctx))
515 			goto push_record;
516 
517 		if (ctx->sg_encrypted_size < required_size)
518 			goto alloc_encrypted;
519 
520 		goto alloc_plaintext;
521 	}
522 
523 send_end:
524 	ret = sk_stream_error(sk, msg->msg_flags, ret);
525 
526 	release_sock(sk);
527 	return copied ? copied : ret;
528 }
529 
530 int tls_sw_sendpage(struct sock *sk, struct page *page,
531 		    int offset, size_t size, int flags)
532 {
533 	struct tls_context *tls_ctx = tls_get_ctx(sk);
534 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
535 	int ret = 0;
536 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
537 	bool eor;
538 	size_t orig_size = size;
539 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
540 	struct scatterlist *sg;
541 	bool full_record;
542 	int record_room;
543 
544 	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
545 		      MSG_SENDPAGE_NOTLAST))
546 		return -ENOTSUPP;
547 
548 	/* No MSG_EOR from splice, only look at MSG_MORE */
549 	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
550 
551 	lock_sock(sk);
552 
553 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
554 
555 	if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
556 		goto sendpage_end;
557 
558 	/* Call the sk_stream functions to manage the sndbuf mem. */
559 	while (size > 0) {
560 		size_t copy, required_size;
561 
562 		if (sk->sk_err) {
563 			ret = sk->sk_err;
564 			goto sendpage_end;
565 		}
566 
567 		full_record = false;
568 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
569 		copy = size;
570 		if (copy >= record_room) {
571 			copy = record_room;
572 			full_record = true;
573 		}
574 		required_size = ctx->sg_plaintext_size + copy +
575 			      tls_ctx->overhead_size;
576 
577 		if (!sk_stream_memory_free(sk))
578 			goto wait_for_sndbuf;
579 alloc_payload:
580 		ret = alloc_encrypted_sg(sk, required_size);
581 		if (ret) {
582 			if (ret != -ENOSPC)
583 				goto wait_for_memory;
584 
585 			/* Adjust copy according to the amount that was
586 			 * actually allocated. The difference is due
587 			 * to max sg elements limit
588 			 */
589 			copy -= required_size - ctx->sg_plaintext_size;
590 			full_record = true;
591 		}
592 
593 		get_page(page);
594 		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
595 		sg_set_page(sg, page, copy, offset);
596 		ctx->sg_plaintext_num_elem++;
597 
598 		sk_mem_charge(sk, copy);
599 		offset += copy;
600 		size -= copy;
601 		ctx->sg_plaintext_size += copy;
602 		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
603 
604 		if (full_record || eor ||
605 		    ctx->sg_plaintext_num_elem ==
606 		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
607 push_record:
608 			ret = tls_push_record(sk, flags, record_type);
609 			if (ret) {
610 				if (ret == -ENOMEM)
611 					goto wait_for_memory;
612 
613 				goto sendpage_end;
614 			}
615 		}
616 		continue;
617 wait_for_sndbuf:
618 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
619 wait_for_memory:
620 		ret = sk_stream_wait_memory(sk, &timeo);
621 		if (ret) {
622 			trim_both_sgl(sk, ctx->sg_plaintext_size);
623 			goto sendpage_end;
624 		}
625 
626 		if (tls_is_pending_closed_record(tls_ctx))
627 			goto push_record;
628 
629 		goto alloc_payload;
630 	}
631 
632 sendpage_end:
633 	if (orig_size > size)
634 		ret = orig_size - size;
635 	else
636 		ret = sk_stream_error(sk, flags, ret);
637 
638 	release_sock(sk);
639 	return ret;
640 }
641 
642 void tls_sw_free_resources(struct sock *sk)
643 {
644 	struct tls_context *tls_ctx = tls_get_ctx(sk);
645 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
646 
647 	if (ctx->aead_send)
648 		crypto_free_aead(ctx->aead_send);
649 
650 	tls_free_both_sg(sk);
651 
652 	kfree(ctx);
653 }
654 
655 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
656 {
657 	char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
658 	struct tls_crypto_info *crypto_info;
659 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
660 	struct tls_sw_context *sw_ctx;
661 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
662 	char *iv, *rec_seq;
663 	int rc = 0;
664 
665 	if (!ctx) {
666 		rc = -EINVAL;
667 		goto out;
668 	}
669 
670 	if (ctx->priv_ctx) {
671 		rc = -EEXIST;
672 		goto out;
673 	}
674 
675 	sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
676 	if (!sw_ctx) {
677 		rc = -ENOMEM;
678 		goto out;
679 	}
680 
681 	ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
682 	ctx->free_resources = tls_sw_free_resources;
683 
684 	crypto_info = &ctx->crypto_send;
685 	switch (crypto_info->cipher_type) {
686 	case TLS_CIPHER_AES_GCM_128: {
687 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
688 		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
689 		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
690 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
691 		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
692 		rec_seq =
693 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
694 		gcm_128_info =
695 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
696 		break;
697 	}
698 	default:
699 		rc = -EINVAL;
700 		goto out;
701 	}
702 
703 	ctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
704 	ctx->tag_size = tag_size;
705 	ctx->overhead_size = ctx->prepend_size + ctx->tag_size;
706 	ctx->iv_size = iv_size;
707 	ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
708 			  GFP_KERNEL);
709 	if (!ctx->iv) {
710 		rc = -ENOMEM;
711 		goto out;
712 	}
713 	memcpy(ctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
714 	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
715 	ctx->rec_seq_size = rec_seq_size;
716 	ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
717 	if (!ctx->rec_seq) {
718 		rc = -ENOMEM;
719 		goto free_iv;
720 	}
721 	memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
722 
723 	sg_init_table(sw_ctx->sg_encrypted_data,
724 		      ARRAY_SIZE(sw_ctx->sg_encrypted_data));
725 	sg_init_table(sw_ctx->sg_plaintext_data,
726 		      ARRAY_SIZE(sw_ctx->sg_plaintext_data));
727 
728 	sg_init_table(sw_ctx->sg_aead_in, 2);
729 	sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
730 		   sizeof(sw_ctx->aad_space));
731 	sg_unmark_end(&sw_ctx->sg_aead_in[1]);
732 	sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
733 	sg_init_table(sw_ctx->sg_aead_out, 2);
734 	sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
735 		   sizeof(sw_ctx->aad_space));
736 	sg_unmark_end(&sw_ctx->sg_aead_out[1]);
737 	sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
738 
739 	if (!sw_ctx->aead_send) {
740 		sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0);
741 		if (IS_ERR(sw_ctx->aead_send)) {
742 			rc = PTR_ERR(sw_ctx->aead_send);
743 			sw_ctx->aead_send = NULL;
744 			goto free_rec_seq;
745 		}
746 	}
747 
748 	ctx->push_pending_record = tls_sw_push_pending_record;
749 
750 	memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
751 
752 	rc = crypto_aead_setkey(sw_ctx->aead_send, keyval,
753 				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
754 	if (rc)
755 		goto free_aead;
756 
757 	rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tag_size);
758 	if (!rc)
759 		goto out;
760 
761 free_aead:
762 	crypto_free_aead(sw_ctx->aead_send);
763 	sw_ctx->aead_send = NULL;
764 free_rec_seq:
765 	kfree(ctx->rec_seq);
766 	ctx->rec_seq = NULL;
767 free_iv:
768 	kfree(ctx->iv);
769 	ctx->iv = NULL;
770 out:
771 	return rc;
772 }
773