xref: /openbmc/linux/arch/arm64/crypto/aes-glue.c (revision 8795a739)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7 
8 #include <asm/neon.h>
9 #include <asm/hwcap.h>
10 #include <asm/simd.h>
11 #include <crypto/aes.h>
12 #include <crypto/ctr.h>
13 #include <crypto/sha.h>
14 #include <crypto/internal/hash.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/scatterwalk.h>
18 #include <linux/module.h>
19 #include <linux/cpufeature.h>
20 #include <crypto/xts.h>
21 
22 #include "aes-ce-setkey.h"
23 
/*
 * This glue code is built twice from the same source. With
 * USE_V8_CRYPTO_EXTENSIONS defined it drives the ARMv8 Crypto Extensions
 * implementation ("ce", priority 300); otherwise it drives the plain NEON one
 * ("neon", priority 200). The macros below rename the generic aes_* entry
 * points used throughout this file to the ce_* or neon_* symbols of the
 * selected assembler implementation.
 *
 * Only the CE build remaps aes_expandkey. The NEON build calls whatever
 * aes_expandkey() the included headers declare (presumably the generic AES
 * library key expansion from <crypto/aes.h> - confirm).
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE			"ce"
#define PRIO			300
#define aes_expandkey		ce_aes_expandkey
#define aes_ecb_encrypt		ce_aes_ecb_encrypt
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE			"neon"
#define PRIO			200
#define aes_ecb_encrypt		neon_aes_ecb_encrypt
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
/*
 * The ecb/cbc/ctr/xts entries in aes_algs[] are compiled in under exactly
 * this condition, so their aliases are only advertised when this module
 * actually provides them. In the NEON build with CONFIG_CRYPTO_AES_ARM64_BS
 * enabled, those modes are presumably provided by that other driver instead.
 */
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");
72 
/*
 * Defined in aes-modes.S. The aes_* names are remapped by the preprocessor
 * to the ce_* or neon_* variant selected at build time.
 *
 * Parameters shared by these routines, as used by every caller in this file:
 *   out/in:  destination and source buffers
 *   rk/rk1:  expanded key schedule (key_enc or key_dec of a crypto_aes_ctx)
 *   rounds:  AES round count, always computed as 6 + key_length / 4
 *   blocks:  number of whole AES_BLOCK_SIZE blocks to process
 */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);

/*
 * Plain CBC. Callers pass walk->iv and invoke these repeatedly across walk
 * steps, so iv is presumably updated in place to carry the chaining value
 * forward - confirm against aes-modes.S.
 */
asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

/*
 * CBC with ciphertext stealing over the final 'bytes' bytes of a message.
 * The callers in this file pass more than one and at most two blocks' worth.
 */
asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

/*
 * CTR mode. ctr_encrypt() passes blocks == -1 with in == NULL to obtain a
 * single block of keystream for a partial tail.
 */
asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 ctr[]);

/*
 * XTS. 'bytes' is a byte count that may include a partial final block
 * (ciphertext stealing). rk2 is the encryption schedule of the second key,
 * used together with iv. 'first' is non-zero on the first call made for a
 * request.
 */
asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);

/*
 * ESSIV-CBC over 'blocks' whole blocks. rk2 is the encryption schedule of
 * the key derived from SHA-256 of the user key; it is passed alongside iv.
 */
asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);

/*
 * CBC-MAC absorb step on the running state dg (updated in place).
 * enc_before/enc_after select whether dg is encrypted before absorbing the
 * input and after the last block. This presumably mirrors the scalar
 * fallback in mac_do_update().
 */
asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			       int blocks, u8 dg[], int enc_before,
			       int enc_after);
109 
/*
 * xts(aes) context. key1 is expanded from the first half of the XTS key and
 * used for the data; key2 is expanded from the second half and passed with
 * the IV. NOTE(review): the reason for __aligned(8) is not visible here.
 */
struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
};

/*
 * essiv(cbc(aes),sha256) context. key1 is the user key (data path); key2 is
 * expanded from SHA-256 of the user key and used on the IV. hash is the
 * sha256 shash allocated in essiv_cbc_init_tfm().
 */
struct crypto_aes_essiv_cbc_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
	struct crypto_shash *hash;
};

/*
 * Per-tfm context of the MAC algorithms. consts holds the two derived
 * subkey blocks used by cmac/xcbc; those algorithms size the context with
 * room for 2 * AES_BLOCK_SIZE extra bytes, while cbcmac allocates none.
 */
struct mac_tfm_ctx {
	struct crypto_aes_ctx key;
	u8 __aligned(8) consts[];
};

/*
 * Per-request MAC state. dg is the running CBC-MAC value into which pending
 * input bytes are XOR-ed; len is how many bytes of the current block have
 * been XOR-ed in so far (0..AES_BLOCK_SIZE).
 */
struct mac_desc_ctx {
	unsigned int len;
	u8 dg[AES_BLOCK_SIZE];
};
130 
131 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
132 			       unsigned int key_len)
133 {
134 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
135 	int ret;
136 
137 	ret = aes_expandkey(ctx, in_key, key_len);
138 	if (ret)
139 		crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
140 
141 	return ret;
142 }
143 
144 static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
145 				      const u8 *in_key, unsigned int key_len)
146 {
147 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
148 	int ret;
149 
150 	ret = xts_verify_key(tfm, in_key, key_len);
151 	if (ret)
152 		return ret;
153 
154 	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
155 	if (!ret)
156 		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
157 				    key_len / 2);
158 	if (!ret)
159 		return 0;
160 
161 	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
162 	return -EINVAL;
163 }
164 
165 static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
166 					    const u8 *in_key,
167 					    unsigned int key_len)
168 {
169 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
170 	SHASH_DESC_ON_STACK(desc, ctx->hash);
171 	u8 digest[SHA256_DIGEST_SIZE];
172 	int ret;
173 
174 	ret = aes_expandkey(&ctx->key1, in_key, key_len);
175 	if (ret)
176 		goto out;
177 
178 	desc->tfm = ctx->hash;
179 	crypto_shash_digest(desc, in_key, key_len, digest);
180 
181 	ret = aes_expandkey(&ctx->key2, digest, sizeof(digest));
182 	if (ret)
183 		goto out;
184 
185 	return 0;
186 out:
187 	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
188 	return -EINVAL;
189 }
190 
191 static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
192 {
193 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
194 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
195 	int err, rounds = 6 + ctx->key_length / 4;
196 	struct skcipher_walk walk;
197 	unsigned int blocks;
198 
199 	err = skcipher_walk_virt(&walk, req, false);
200 
201 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
202 		kernel_neon_begin();
203 		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
204 				ctx->key_enc, rounds, blocks);
205 		kernel_neon_end();
206 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
207 	}
208 	return err;
209 }
210 
211 static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
212 {
213 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
214 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
215 	int err, rounds = 6 + ctx->key_length / 4;
216 	struct skcipher_walk walk;
217 	unsigned int blocks;
218 
219 	err = skcipher_walk_virt(&walk, req, false);
220 
221 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
222 		kernel_neon_begin();
223 		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
224 				ctx->key_dec, rounds, blocks);
225 		kernel_neon_end();
226 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
227 	}
228 	return err;
229 }
230 
231 static int cbc_encrypt_walk(struct skcipher_request *req,
232 			    struct skcipher_walk *walk)
233 {
234 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
235 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
236 	int err = 0, rounds = 6 + ctx->key_length / 4;
237 	unsigned int blocks;
238 
239 	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
240 		kernel_neon_begin();
241 		aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
242 				ctx->key_enc, rounds, blocks, walk->iv);
243 		kernel_neon_end();
244 		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
245 	}
246 	return err;
247 }
248 
249 static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
250 {
251 	struct skcipher_walk walk;
252 	int err;
253 
254 	err = skcipher_walk_virt(&walk, req, false);
255 	if (err)
256 		return err;
257 	return cbc_encrypt_walk(req, &walk);
258 }
259 
260 static int cbc_decrypt_walk(struct skcipher_request *req,
261 			    struct skcipher_walk *walk)
262 {
263 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
264 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
265 	int err = 0, rounds = 6 + ctx->key_length / 4;
266 	unsigned int blocks;
267 
268 	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
269 		kernel_neon_begin();
270 		aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
271 				ctx->key_dec, rounds, blocks, walk->iv);
272 		kernel_neon_end();
273 		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
274 	}
275 	return err;
276 }
277 
278 static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
279 {
280 	struct skcipher_walk walk;
281 	int err;
282 
283 	err = skcipher_walk_virt(&walk, req, false);
284 	if (err)
285 		return err;
286 	return cbc_decrypt_walk(req, &walk);
287 }
288 
/*
 * cts_cbc_encrypt - CBC encryption with ciphertext stealing (cts(cbc(aes))).
 *
 * Lengths below one block are rejected with -EINVAL, and a single block is
 * plain CBC. Otherwise all but the last two (possibly partial) blocks go
 * through ordinary CBC via cbc_encrypt_walk() on a subrequest. The remaining
 * 17..32 bytes are then handed to aes_cbc_cts_encrypt() in one call.
 */
static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	/* blocks handled by plain CBC: everything except the final pair */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		/* exactly one block: nothing to steal, plain CBC */
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_encrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* position src/dst just past the CBC-processed head */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	/*
	 * NOTE(review): relies on req->iv having been updated by the CBC pass
	 * above (presumably walk.iv == req->iv since no alignmask is set), and
	 * on the whole 17..32 byte tail being mapped by one walk step
	 * (presumably via .walksize in aes_algs) - confirm.
	 */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
345 
/*
 * cts_cbc_decrypt - CBC decryption with ciphertext stealing (cts(cbc(aes))).
 *
 * Mirrors cts_cbc_encrypt(). Lengths below one block are rejected with
 * -EINVAL, and a single block is plain CBC. Otherwise all but the last two
 * (possibly partial) blocks are decrypted via cbc_decrypt_walk() on a
 * subrequest, and the remaining 17..32 bytes go to aes_cbc_cts_decrypt().
 */
static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	/* blocks handled by plain CBC: everything except the final pair */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		/* exactly one block: nothing to steal, plain CBC */
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_decrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* position src/dst just past the CBC-processed head */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	/*
	 * NOTE(review): relies on req->iv having been updated by the CBC pass
	 * above and on the whole tail being mapped by one walk step
	 * (presumably via .walksize in aes_algs) - confirm.
	 */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
402 
403 static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
404 {
405 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
406 
407 	ctx->hash = crypto_alloc_shash("sha256", 0, 0);
408 
409 	return PTR_ERR_OR_ZERO(ctx->hash);
410 }
411 
412 static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
413 {
414 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
415 
416 	crypto_free_shash(ctx->hash);
417 }
418 
419 static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
420 {
421 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
422 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
423 	int err, rounds = 6 + ctx->key1.key_length / 4;
424 	struct skcipher_walk walk;
425 	unsigned int blocks;
426 
427 	err = skcipher_walk_virt(&walk, req, false);
428 
429 	blocks = walk.nbytes / AES_BLOCK_SIZE;
430 	if (blocks) {
431 		kernel_neon_begin();
432 		aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
433 				      ctx->key1.key_enc, rounds, blocks,
434 				      req->iv, ctx->key2.key_enc);
435 		kernel_neon_end();
436 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
437 	}
438 	return err ?: cbc_encrypt_walk(req, &walk);
439 }
440 
441 static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
442 {
443 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
444 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
445 	int err, rounds = 6 + ctx->key1.key_length / 4;
446 	struct skcipher_walk walk;
447 	unsigned int blocks;
448 
449 	err = skcipher_walk_virt(&walk, req, false);
450 
451 	blocks = walk.nbytes / AES_BLOCK_SIZE;
452 	if (blocks) {
453 		kernel_neon_begin();
454 		aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
455 				      ctx->key1.key_dec, rounds, blocks,
456 				      req->iv, ctx->key2.key_enc);
457 		kernel_neon_end();
458 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
459 	}
460 	return err ?: cbc_decrypt_walk(req, &walk);
461 }
462 
463 static int ctr_encrypt(struct skcipher_request *req)
464 {
465 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
466 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
467 	int err, rounds = 6 + ctx->key_length / 4;
468 	struct skcipher_walk walk;
469 	int blocks;
470 
471 	err = skcipher_walk_virt(&walk, req, false);
472 
473 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
474 		kernel_neon_begin();
475 		aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
476 				ctx->key_enc, rounds, blocks, walk.iv);
477 		kernel_neon_end();
478 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
479 	}
480 	if (walk.nbytes) {
481 		u8 __aligned(8) tail[AES_BLOCK_SIZE];
482 		unsigned int nbytes = walk.nbytes;
483 		u8 *tdst = walk.dst.virt.addr;
484 		u8 *tsrc = walk.src.virt.addr;
485 
486 		/*
487 		 * Tell aes_ctr_encrypt() to process a tail block.
488 		 */
489 		blocks = -1;
490 
491 		kernel_neon_begin();
492 		aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
493 				blocks, walk.iv);
494 		kernel_neon_end();
495 		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
496 		err = skcipher_walk_done(&walk, 0);
497 	}
498 
499 	return err;
500 }
501 
502 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
503 {
504 	const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
505 	unsigned long flags;
506 
507 	/*
508 	 * Temporarily disable interrupts to avoid races where
509 	 * cachelines are evicted when the CPU is interrupted
510 	 * to do something else.
511 	 */
512 	local_irq_save(flags);
513 	aes_encrypt(ctx, dst, src);
514 	local_irq_restore(flags);
515 }
516 
517 static int __maybe_unused ctr_encrypt_sync(struct skcipher_request *req)
518 {
519 	if (!crypto_simd_usable())
520 		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
521 
522 	return ctr_encrypt(req);
523 }
524 
/*
 * xts_encrypt - XTS-AES encryption with ciphertext stealing.
 *
 * Requests shorter than one AES block are rejected with -EINVAL.
 *
 * In the simple case, the length is a block multiple or the first walk step
 * already maps the whole request. The data is then processed step by step,
 * and the final step's full length (including any partial tail) is passed
 * to aes_xts_encrypt().
 *
 * Otherwise the request is split in two. A subrequest first covers all but
 * the last two (possibly partial) blocks. The remaining AES_BLOCK_SIZE + tail
 * bytes are then processed in a single call.
 */
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		/* blocks preceding the final full block + partial tail */
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		/* from here on, req refers to the block-aligned head */
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		/* any tail is handled by the last iteration of the loop */
		tail = 0;
	}

	/* first is 1 only for the first aes_xts_encrypt() call */
	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		/* only the final step may pass a non-block-multiple length */
		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_enc, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* req is &subreq here: skip the head that was just processed */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	/*
	 * Reuse subreq for the last full block plus the partial tail.
	 * NOTE(review): relies on req->iv holding the tweak left by the head
	 * pass (presumably walk.iv == req->iv since no alignmask is set) and on
	 * the whole 16 + tail bytes being mapped by one walk step (presumably
	 * via .walksize in aes_algs) - confirm.
	 */
	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_enc, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
596 
/*
 * xts_decrypt - XTS-AES decryption with ciphertext stealing.
 *
 * Same request splitting as xts_encrypt(). Data is processed with key1's
 * decryption schedule; key2's encryption schedule is passed with the IV.
 * Requests shorter than one block are rejected with -EINVAL.
 */
static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		/* blocks preceding the final full block + partial tail */
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		/* from here on, req refers to the block-aligned head */
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		/* any tail is handled by the last iteration of the loop */
		tail = 0;
	}

	/* first is 1 only for the first aes_xts_decrypt() call */
	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		/* only the final step may pass a non-block-multiple length */
		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_dec, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* req is &subreq here: skip the head that was just processed */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	/*
	 * Reuse subreq for the last full block plus the partial tail.
	 * NOTE(review): same IV / single-walk-step assumptions as
	 * xts_encrypt() - confirm.
	 */
	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;


	kernel_neon_begin();
	aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_dec, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
669 
/*
 * Skcipher algorithms registered by aes_init().
 *
 * Entries flagged CRYPTO_ALG_INTERNAL carry a "__" prefix in both cra_name
 * and cra_driver_name. aes_init() wraps each of them in a simd skcipher,
 * registered under the same names with the two-character prefix removed.
 *
 * The "ctr-aes-" MODE entry is not internal. It is registered directly at
 * PRIO - 1 and uses ctr_encrypt_sync(), which falls back to scalar AES when
 * SIMD is not usable.
 *
 * The ecb/cbc/ctr/xts entries are compiled in under the same condition as
 * their MODULE_ALIAS_CRYPTO() lines. The "}, {" that opens the cts entry sits
 * inside the #if, so the initializer is well-formed either way.
 */
static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
	.base = {
		.cra_name		= "__ecb(aes)",
		.cra_driver_name	= "__ecb-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ecb_encrypt,
	.decrypt	= ecb_decrypt,
}, {
	.base = {
		.cra_name		= "__cbc(aes)",
		.cra_driver_name	= "__cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cbc_encrypt,
	.decrypt	= cbc_decrypt,
}, {
	/* CTR is a stream mode: blocksize 1, processed in AES-block chunks */
	.base = {
		.cra_name		= "__ctr(aes)",
		.cra_driver_name	= "__ctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt,
	.decrypt	= ctr_encrypt,
}, {
	/* synchronous ctr(aes), usable without SIMD (see ctr_encrypt_sync) */
	.base = {
		.cra_name		= "ctr(aes)",
		.cra_driver_name	= "ctr-aes-" MODE,
		.cra_priority		= PRIO - 1,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt_sync,
	.decrypt	= ctr_encrypt_sync,
}, {
	/* XTS key is two AES keys back to back, hence the doubled sizes */
	.base = {
		.cra_name		= "__xts(aes)",
		.cra_driver_name	= "__xts-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	/* NOTE(review): presumably lets the CTS tail be walked in one step */
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= xts_set_key,
	.encrypt	= xts_encrypt,
	.decrypt	= xts_decrypt,
}, {
#endif
	.base = {
		.cra_name		= "__cts(cbc(aes))",
		.cra_driver_name	= "__cts-cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	/* NOTE(review): presumably lets the CTS tail be walked in one step */
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cts_cbc_encrypt,
	.decrypt	= cts_cbc_decrypt,
}, {
	.base = {
		.cra_name		= "__essiv(cbc(aes),sha256)",
		.cra_driver_name	= "__essiv-cbc-aes-sha256-" MODE,
		.cra_priority		= PRIO + 1,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_essiv_cbc_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= essiv_cbc_set_key,
	.encrypt	= essiv_cbc_encrypt,
	.decrypt	= essiv_cbc_decrypt,
	/* init/exit allocate and free the sha256 shash used by setkey */
	.init		= essiv_cbc_init_tfm,
	.exit		= essiv_cbc_exit_tfm,
} };
789 
790 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
791 			 unsigned int key_len)
792 {
793 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
794 	int err;
795 
796 	err = aes_expandkey(&ctx->key, in_key, key_len);
797 	if (err)
798 		crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
799 
800 	return err;
801 }
802 
803 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
804 {
805 	u64 a = be64_to_cpu(x->a);
806 	u64 b = be64_to_cpu(x->b);
807 
808 	y->a = cpu_to_be64((a << 1) | (b >> 63));
809 	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
810 }
811 
812 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
813 		       unsigned int key_len)
814 {
815 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
816 	be128 *consts = (be128 *)ctx->consts;
817 	int rounds = 6 + key_len / 4;
818 	int err;
819 
820 	err = cbcmac_setkey(tfm, in_key, key_len);
821 	if (err)
822 		return err;
823 
824 	/* encrypt the zero vector */
825 	kernel_neon_begin();
826 	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
827 			rounds, 1);
828 	kernel_neon_end();
829 
830 	cmac_gf128_mul_by_x(consts, consts);
831 	cmac_gf128_mul_by_x(consts + 1, consts);
832 
833 	return 0;
834 }
835 
836 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
837 		       unsigned int key_len)
838 {
839 	static u8 const ks[3][AES_BLOCK_SIZE] = {
840 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
841 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
842 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
843 	};
844 
845 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
846 	int rounds = 6 + key_len / 4;
847 	u8 key[AES_BLOCK_SIZE];
848 	int err;
849 
850 	err = cbcmac_setkey(tfm, in_key, key_len);
851 	if (err)
852 		return err;
853 
854 	kernel_neon_begin();
855 	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
856 	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
857 	kernel_neon_end();
858 
859 	return cbcmac_setkey(tfm, key, sizeof(key));
860 }
861 
862 static int mac_init(struct shash_desc *desc)
863 {
864 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
865 
866 	memset(ctx->dg, 0, AES_BLOCK_SIZE);
867 	ctx->len = 0;
868 
869 	return 0;
870 }
871 
872 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
873 			  u8 dg[], int enc_before, int enc_after)
874 {
875 	int rounds = 6 + ctx->key_length / 4;
876 
877 	if (crypto_simd_usable()) {
878 		kernel_neon_begin();
879 		aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
880 			       enc_after);
881 		kernel_neon_end();
882 	} else {
883 		if (enc_before)
884 			aes_encrypt(ctx, dg, dg);
885 
886 		while (blocks--) {
887 			crypto_xor(dg, in, AES_BLOCK_SIZE);
888 			in += AES_BLOCK_SIZE;
889 
890 			if (blocks || enc_after)
891 				aes_encrypt(ctx, dg, dg);
892 		}
893 	}
894 }
895 
896 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
897 {
898 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
899 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
900 
901 	while (len > 0) {
902 		unsigned int l;
903 
904 		if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
905 		    (ctx->len + len) > AES_BLOCK_SIZE) {
906 
907 			int blocks = len / AES_BLOCK_SIZE;
908 
909 			len %= AES_BLOCK_SIZE;
910 
911 			mac_do_update(&tctx->key, p, blocks, ctx->dg,
912 				      (ctx->len != 0), (len != 0));
913 
914 			p += blocks * AES_BLOCK_SIZE;
915 
916 			if (!len) {
917 				ctx->len = AES_BLOCK_SIZE;
918 				break;
919 			}
920 			ctx->len = 0;
921 		}
922 
923 		l = min(len, AES_BLOCK_SIZE - ctx->len);
924 
925 		if (l <= AES_BLOCK_SIZE) {
926 			crypto_xor(ctx->dg + ctx->len, p, l);
927 			ctx->len += l;
928 			len -= l;
929 			p += l;
930 		}
931 	}
932 
933 	return 0;
934 }
935 
936 static int cbcmac_final(struct shash_desc *desc, u8 *out)
937 {
938 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
939 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
940 
941 	mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
942 
943 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
944 
945 	return 0;
946 }
947 
948 static int cmac_final(struct shash_desc *desc, u8 *out)
949 {
950 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
951 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
952 	u8 *consts = tctx->consts;
953 
954 	if (ctx->len != AES_BLOCK_SIZE) {
955 		ctx->dg[ctx->len] ^= 0x80;
956 		consts += AES_BLOCK_SIZE;
957 	}
958 
959 	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
960 
961 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
962 
963 	return 0;
964 }
965 
/*
 * MAC (shash) algorithms registered by aes_init().
 *
 * cmac and xcbc share update/final and differ only in key setup. Both
 * reserve 2 * AES_BLOCK_SIZE bytes after struct mac_tfm_ctx for the derived
 * constants (mac_tfm_ctx.consts). cbcmac needs no constants and accepts
 * arbitrary-length input, hence cra_blocksize 1.
 */
static struct shash_alg mac_algs[] = { {
	.base.cra_name		= "cmac(aes)",
	.base.cra_driver_name	= "cmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= cmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "xcbc(aes)",
	.base.cra_driver_name	= "xcbc-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= xcbc_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "cbcmac(aes)",
	.base.cra_driver_name	= "cbcmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cbcmac_final,
	.setkey			= cbcmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
} };
1011 
/*
 * simd wrappers created by aes_init(), indexed like aes_algs[]. Slots for
 * non-internal entries (and for entries not yet wrapped) stay NULL.
 */
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
1013 
1014 static void aes_exit(void)
1015 {
1016 	int i;
1017 
1018 	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
1019 		if (aes_simd_algs[i])
1020 			simd_skcipher_free(aes_simd_algs[i]);
1021 
1022 	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1023 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1024 }
1025 
1026 static int __init aes_init(void)
1027 {
1028 	struct simd_skcipher_alg *simd;
1029 	const char *basename;
1030 	const char *algname;
1031 	const char *drvname;
1032 	int err;
1033 	int i;
1034 
1035 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1036 	if (err)
1037 		return err;
1038 
1039 	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1040 	if (err)
1041 		goto unregister_ciphers;
1042 
1043 	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
1044 		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
1045 			continue;
1046 
1047 		algname = aes_algs[i].base.cra_name + 2;
1048 		drvname = aes_algs[i].base.cra_driver_name + 2;
1049 		basename = aes_algs[i].base.cra_driver_name;
1050 		simd = simd_skcipher_create_compat(algname, drvname, basename);
1051 		err = PTR_ERR(simd);
1052 		if (IS_ERR(simd))
1053 			goto unregister_simds;
1054 
1055 		aes_simd_algs[i] = simd;
1056 	}
1057 
1058 	return 0;
1059 
1060 unregister_simds:
1061 	aes_exit();
1062 	return err;
1063 unregister_ciphers:
1064 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1065 	return err;
1066 }
1067 
/*
 * The CE build is bound to the CPU's AES feature via
 * module_cpu_feature_match(). The NEON build initialises unconditionally and
 * exports some of its NEON routines to other modules (presumably the
 * CONFIG_CRYPTO_AES_ARM64_BS driver - confirm the users).
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);
1078