xref: /openbmc/linux/arch/arm64/crypto/aes-glue.c (revision 04eb94d526423ff082efce61f4f26b0369d0bfdd)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7 
8 #include <asm/neon.h>
9 #include <asm/hwcap.h>
10 #include <asm/simd.h>
11 #include <crypto/aes.h>
12 #include <crypto/internal/hash.h>
13 #include <crypto/internal/simd.h>
14 #include <crypto/internal/skcipher.h>
15 #include <crypto/scatterwalk.h>
16 #include <linux/module.h>
17 #include <linux/cpufeature.h>
18 #include <crypto/xts.h>
19 
20 #include "aes-ce-setkey.h"
21 #include "aes-ctr-fallback.h"
22 
/*
 * This file is built in one of two flavours, chosen at compile time by
 * USE_V8_CRYPTO_EXTENSIONS. Each flavour binds the generic aes_* names used
 * below to the "ce_" or "neon_" prefixed routines, and picks the driver name
 * suffix (MODE) and the registration priority (PRIO).
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE			"ce"
#define PRIO			300
#define aes_setkey		ce_aes_setkey
#define aes_expandkey		ce_aes_expandkey
#define aes_ecb_encrypt		ce_aes_ecb_encrypt
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
/* Plain NEON flavour: lower priority, generic key schedule routines. */
#define MODE			"neon"
#define PRIO			200
#define aes_setkey		crypto_aes_set_key
#define aes_expandkey		crypto_aes_expand_key
#define aes_ecb_encrypt		neon_aes_ecb_encrypt
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
/* Only the NEON flavour declares crypto aliases for the algorithm names. */
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");
#endif

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");
66 
/*
 * Assembly routines defined in aes-modes.S. The aes_* names are macros that
 * resolve to the "ce_" or "neon_" prefixed symbols selected above.
 *
 * Common arguments: out/in are the destination and source buffers, rk is the
 * expanded round-key array, rounds is the AES round count (callers in this
 * file compute it as 6 + key_length / 4) and blocks is the number of
 * AES_BLOCK_SIZE blocks to process.
 */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);

/* CBC: iv is passed non-const; callers hand in the walk's IV buffer. */
asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

/* CBC-CTS tail: takes a byte count rather than a block count. */
asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

/*
 * CTR: ctr_encrypt() below passes blocks == -1 with in == NULL to request a
 * single keystream block for a partial tail.
 */
asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 ctr[]);

/*
 * XTS: rk1 and rk2 are the two key schedules; callers pass first == 1 on the
 * first call for a request and 0 on subsequent calls.
 */
asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int blocks, u32 const rk2[], u8 iv[],
				int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int blocks, u32 const rk2[], u8 iv[],
				int first);

/*
 * CBC-MAC style update of the digest dg over 'blocks' input blocks.
 * enc_before/enc_after select whether dg is encrypted before the first block
 * is absorbed and after the last one; see the scalar equivalent in
 * mac_do_update().
 */
asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			       int blocks, u8 dg[], int enc_before,
			       int enc_after);
96 
/*
 * Per-request context for the cts(cbc(aes)) implementation; its size is
 * registered as the request size in cts_cbc_init_tfm().
 */
struct cts_cbc_req_ctx {
	/* scratch scatterlists filled by scatterwalk_ffwd() for the CTS tail */
	struct scatterlist sg_src[2];
	struct scatterlist sg_dst[2];
	/* sub-request used to walk the CBC part and then the CTS tail */
	struct skcipher_request subreq;
};
102 
/*
 * Transform context for xts(aes): two independent AES key schedules.
 * xts_set_key() expands the first half of the supplied key into key1 and
 * the second half into key2.
 */
struct crypto_aes_xts_ctx {
	/* passed as rk1 to aes_xts_encrypt()/aes_xts_decrypt() */
	struct crypto_aes_ctx key1;
	/* passed as rk2; only its encryption schedule is used in this file */
	struct crypto_aes_ctx __aligned(8) key2;
};
107 
/*
 * Transform context shared by the cmac, xcbc and cbcmac implementations.
 */
struct mac_tfm_ctx {
	/* AES key schedule used for every block encryption of the MAC */
	struct crypto_aes_ctx key;
	/*
	 * Trailing storage for derived constants. cmac and xcbc reserve
	 * 2 * AES_BLOCK_SIZE bytes for it via cra_ctxsize; cbcmac reserves
	 * none.
	 */
	u8 __aligned(8) consts[];
};
112 
/*
 * Per-operation MAC state, reset by mac_init().
 */
struct mac_desc_ctx {
	/*
	 * Number of bytes that have been XORed into dg but not yet encrypted
	 * (0 .. AES_BLOCK_SIZE), as maintained by mac_update().
	 */
	unsigned int len;
	/* running digest block */
	u8 dg[AES_BLOCK_SIZE];
};
117 
118 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
119 			       unsigned int key_len)
120 {
121 	return aes_setkey(crypto_skcipher_tfm(tfm), in_key, key_len);
122 }
123 
124 static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
125 		       unsigned int key_len)
126 {
127 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
128 	int ret;
129 
130 	ret = xts_verify_key(tfm, in_key, key_len);
131 	if (ret)
132 		return ret;
133 
134 	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
135 	if (!ret)
136 		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
137 				    key_len / 2);
138 	if (!ret)
139 		return 0;
140 
141 	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
142 	return -EINVAL;
143 }
144 
145 static int ecb_encrypt(struct skcipher_request *req)
146 {
147 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
148 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
149 	int err, rounds = 6 + ctx->key_length / 4;
150 	struct skcipher_walk walk;
151 	unsigned int blocks;
152 
153 	err = skcipher_walk_virt(&walk, req, false);
154 
155 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
156 		kernel_neon_begin();
157 		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
158 				ctx->key_enc, rounds, blocks);
159 		kernel_neon_end();
160 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
161 	}
162 	return err;
163 }
164 
165 static int ecb_decrypt(struct skcipher_request *req)
166 {
167 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
168 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
169 	int err, rounds = 6 + ctx->key_length / 4;
170 	struct skcipher_walk walk;
171 	unsigned int blocks;
172 
173 	err = skcipher_walk_virt(&walk, req, false);
174 
175 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
176 		kernel_neon_begin();
177 		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
178 				ctx->key_dec, rounds, blocks);
179 		kernel_neon_end();
180 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
181 	}
182 	return err;
183 }
184 
185 static int cbc_encrypt(struct skcipher_request *req)
186 {
187 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
188 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
189 	int err, rounds = 6 + ctx->key_length / 4;
190 	struct skcipher_walk walk;
191 	unsigned int blocks;
192 
193 	err = skcipher_walk_virt(&walk, req, false);
194 
195 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
196 		kernel_neon_begin();
197 		aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
198 				ctx->key_enc, rounds, blocks, walk.iv);
199 		kernel_neon_end();
200 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
201 	}
202 	return err;
203 }
204 
205 static int cbc_decrypt(struct skcipher_request *req)
206 {
207 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
208 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
209 	int err, rounds = 6 + ctx->key_length / 4;
210 	struct skcipher_walk walk;
211 	unsigned int blocks;
212 
213 	err = skcipher_walk_virt(&walk, req, false);
214 
215 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
216 		kernel_neon_begin();
217 		aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
218 				ctx->key_dec, rounds, blocks, walk.iv);
219 		kernel_neon_end();
220 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
221 	}
222 	return err;
223 }
224 
225 static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
226 {
227 	crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
228 	return 0;
229 }
230 
231 static int cts_cbc_encrypt(struct skcipher_request *req)
232 {
233 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
234 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
235 	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
236 	int err, rounds = 6 + ctx->key_length / 4;
237 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
238 	struct scatterlist *src = req->src, *dst = req->dst;
239 	struct skcipher_walk walk;
240 
241 	skcipher_request_set_tfm(&rctx->subreq, tfm);
242 
243 	if (req->cryptlen <= AES_BLOCK_SIZE) {
244 		if (req->cryptlen < AES_BLOCK_SIZE)
245 			return -EINVAL;
246 		cbc_blocks = 1;
247 	}
248 
249 	if (cbc_blocks > 0) {
250 		unsigned int blocks;
251 
252 		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
253 					   cbc_blocks * AES_BLOCK_SIZE,
254 					   req->iv);
255 
256 		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
257 
258 		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
259 			kernel_neon_begin();
260 			aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
261 					ctx->key_enc, rounds, blocks, walk.iv);
262 			kernel_neon_end();
263 			err = skcipher_walk_done(&walk,
264 						 walk.nbytes % AES_BLOCK_SIZE);
265 		}
266 		if (err)
267 			return err;
268 
269 		if (req->cryptlen == AES_BLOCK_SIZE)
270 			return 0;
271 
272 		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
273 					     rctx->subreq.cryptlen);
274 		if (req->dst != req->src)
275 			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
276 					       rctx->subreq.cryptlen);
277 	}
278 
279 	/* handle ciphertext stealing */
280 	skcipher_request_set_crypt(&rctx->subreq, src, dst,
281 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
282 				   req->iv);
283 
284 	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
285 	if (err)
286 		return err;
287 
288 	kernel_neon_begin();
289 	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
290 			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
291 	kernel_neon_end();
292 
293 	return skcipher_walk_done(&walk, 0);
294 }
295 
296 static int cts_cbc_decrypt(struct skcipher_request *req)
297 {
298 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
299 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
300 	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
301 	int err, rounds = 6 + ctx->key_length / 4;
302 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
303 	struct scatterlist *src = req->src, *dst = req->dst;
304 	struct skcipher_walk walk;
305 
306 	skcipher_request_set_tfm(&rctx->subreq, tfm);
307 
308 	if (req->cryptlen <= AES_BLOCK_SIZE) {
309 		if (req->cryptlen < AES_BLOCK_SIZE)
310 			return -EINVAL;
311 		cbc_blocks = 1;
312 	}
313 
314 	if (cbc_blocks > 0) {
315 		unsigned int blocks;
316 
317 		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
318 					   cbc_blocks * AES_BLOCK_SIZE,
319 					   req->iv);
320 
321 		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
322 
323 		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
324 			kernel_neon_begin();
325 			aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
326 					ctx->key_dec, rounds, blocks, walk.iv);
327 			kernel_neon_end();
328 			err = skcipher_walk_done(&walk,
329 						 walk.nbytes % AES_BLOCK_SIZE);
330 		}
331 		if (err)
332 			return err;
333 
334 		if (req->cryptlen == AES_BLOCK_SIZE)
335 			return 0;
336 
337 		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
338 					     rctx->subreq.cryptlen);
339 		if (req->dst != req->src)
340 			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
341 					       rctx->subreq.cryptlen);
342 	}
343 
344 	/* handle ciphertext stealing */
345 	skcipher_request_set_crypt(&rctx->subreq, src, dst,
346 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
347 				   req->iv);
348 
349 	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
350 	if (err)
351 		return err;
352 
353 	kernel_neon_begin();
354 	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
355 			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
356 	kernel_neon_end();
357 
358 	return skcipher_walk_done(&walk, 0);
359 }
360 
361 static int ctr_encrypt(struct skcipher_request *req)
362 {
363 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
364 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
365 	int err, rounds = 6 + ctx->key_length / 4;
366 	struct skcipher_walk walk;
367 	int blocks;
368 
369 	err = skcipher_walk_virt(&walk, req, false);
370 
371 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
372 		kernel_neon_begin();
373 		aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
374 				ctx->key_enc, rounds, blocks, walk.iv);
375 		kernel_neon_end();
376 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
377 	}
378 	if (walk.nbytes) {
379 		u8 __aligned(8) tail[AES_BLOCK_SIZE];
380 		unsigned int nbytes = walk.nbytes;
381 		u8 *tdst = walk.dst.virt.addr;
382 		u8 *tsrc = walk.src.virt.addr;
383 
384 		/*
385 		 * Tell aes_ctr_encrypt() to process a tail block.
386 		 */
387 		blocks = -1;
388 
389 		kernel_neon_begin();
390 		aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
391 				blocks, walk.iv);
392 		kernel_neon_end();
393 		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
394 		err = skcipher_walk_done(&walk, 0);
395 	}
396 
397 	return err;
398 }
399 
/*
 * Entry point for the synchronous "ctr(aes)" algorithm: uses the SIMD
 * implementation when crypto_simd_usable() allows it and otherwise defers to
 * aes_ctr_encrypt_fallback().
 */
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (crypto_simd_usable())
		return ctr_encrypt(req);

	return aes_ctr_encrypt_fallback(ctx, req);
}
410 
411 static int xts_encrypt(struct skcipher_request *req)
412 {
413 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
414 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
415 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
416 	struct skcipher_walk walk;
417 	unsigned int blocks;
418 
419 	err = skcipher_walk_virt(&walk, req, false);
420 
421 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
422 		kernel_neon_begin();
423 		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
424 				ctx->key1.key_enc, rounds, blocks,
425 				ctx->key2.key_enc, walk.iv, first);
426 		kernel_neon_end();
427 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
428 	}
429 
430 	return err;
431 }
432 
433 static int xts_decrypt(struct skcipher_request *req)
434 {
435 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
436 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
437 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
438 	struct skcipher_walk walk;
439 	unsigned int blocks;
440 
441 	err = skcipher_walk_virt(&walk, req, false);
442 
443 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
444 		kernel_neon_begin();
445 		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
446 				ctx->key1.key_dec, rounds, blocks,
447 				ctx->key2.key_enc, walk.iv, first);
448 		kernel_neon_end();
449 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
450 	}
451 
452 	return err;
453 }
454 
/*
 * skcipher algorithms registered by this module.
 *
 * Entries flagged CRYPTO_ALG_INTERNAL carry a "__" prefix on both names;
 * aes_init() strips that prefix and creates a SIMD wrapper for each of them.
 * The one entry without the flag, "ctr(aes)", is registered directly.
 */
static struct skcipher_alg aes_algs[] = { {
	/* ECB */
	.base = {
		.cra_name		= "__ecb(aes)",
		.cra_driver_name	= "__ecb-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ecb_encrypt,
	.decrypt	= ecb_decrypt,
}, {
	/* CBC */
	.base = {
		.cra_name		= "__cbc(aes)",
		.cra_driver_name	= "__cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cbc_encrypt,
	.decrypt	= cbc_decrypt,
}, {
	/* CBC with ciphertext stealing; needs a per-request context (.init) */
	.base = {
		.cra_name		= "__cts(cbc(aes))",
		.cra_driver_name	= "__cts-cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cts_cbc_encrypt,
	.decrypt	= cts_cbc_decrypt,
	.init		= cts_cbc_init_tfm,
}, {
	/* CTR: block size 1, the same routine serves both directions */
	.base = {
		.cra_name		= "__ctr(aes)",
		.cra_driver_name	= "__ctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt,
	.decrypt	= ctr_encrypt,
}, {
	/*
	 * Directly usable CTR variant (not CRYPTO_ALG_INTERNAL), registered
	 * one priority step below the others. Its handler falls back to
	 * aes_ctr_encrypt_fallback() when SIMD is not usable.
	 */
	.base = {
		.cra_name		= "ctr(aes)",
		.cra_driver_name	= "ctr-aes-" MODE,
		.cra_priority		= PRIO - 1,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt_sync,
	.decrypt	= ctr_encrypt_sync,
}, {
	/* XTS: two AES keys, hence doubled key sizes and the larger context */
	.base = {
		.cra_name		= "__xts(aes)",
		.cra_driver_name	= "__xts-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= xts_set_key,
	.encrypt	= xts_encrypt,
	.decrypt	= xts_decrypt,
} };
554 
555 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
556 			 unsigned int key_len)
557 {
558 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
559 	int err;
560 
561 	err = aes_expandkey(&ctx->key, in_key, key_len);
562 	if (err)
563 		crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
564 
565 	return err;
566 }
567 
568 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
569 {
570 	u64 a = be64_to_cpu(x->a);
571 	u64 b = be64_to_cpu(x->b);
572 
573 	y->a = cpu_to_be64((a << 1) | (b >> 63));
574 	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
575 }
576 
577 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
578 		       unsigned int key_len)
579 {
580 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
581 	be128 *consts = (be128 *)ctx->consts;
582 	int rounds = 6 + key_len / 4;
583 	int err;
584 
585 	err = cbcmac_setkey(tfm, in_key, key_len);
586 	if (err)
587 		return err;
588 
589 	/* encrypt the zero vector */
590 	kernel_neon_begin();
591 	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
592 			rounds, 1);
593 	kernel_neon_end();
594 
595 	cmac_gf128_mul_by_x(consts, consts);
596 	cmac_gf128_mul_by_x(consts + 1, consts);
597 
598 	return 0;
599 }
600 
601 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
602 		       unsigned int key_len)
603 {
604 	static u8 const ks[3][AES_BLOCK_SIZE] = {
605 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
606 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
607 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
608 	};
609 
610 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
611 	int rounds = 6 + key_len / 4;
612 	u8 key[AES_BLOCK_SIZE];
613 	int err;
614 
615 	err = cbcmac_setkey(tfm, in_key, key_len);
616 	if (err)
617 		return err;
618 
619 	kernel_neon_begin();
620 	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
621 	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
622 	kernel_neon_end();
623 
624 	return cbcmac_setkey(tfm, key, sizeof(key));
625 }
626 
627 static int mac_init(struct shash_desc *desc)
628 {
629 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
630 
631 	memset(ctx->dg, 0, AES_BLOCK_SIZE);
632 	ctx->len = 0;
633 
634 	return 0;
635 }
636 
637 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
638 			  u8 dg[], int enc_before, int enc_after)
639 {
640 	int rounds = 6 + ctx->key_length / 4;
641 
642 	if (crypto_simd_usable()) {
643 		kernel_neon_begin();
644 		aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
645 			       enc_after);
646 		kernel_neon_end();
647 	} else {
648 		if (enc_before)
649 			__aes_arm64_encrypt(ctx->key_enc, dg, dg, rounds);
650 
651 		while (blocks--) {
652 			crypto_xor(dg, in, AES_BLOCK_SIZE);
653 			in += AES_BLOCK_SIZE;
654 
655 			if (blocks || enc_after)
656 				__aes_arm64_encrypt(ctx->key_enc, dg, dg,
657 						    rounds);
658 		}
659 	}
660 }
661 
/*
 * update hook shared by all three MACs: absorbs len bytes at p.
 *
 * ctx->len counts the bytes that have been XORed into ctx->dg but not yet
 * encrypted. A completely filled block is deliberately left unencrypted
 * (ctx->len == AES_BLOCK_SIZE) until more data arrives, so that the final
 * hook can still treat it as the last block. Always returns 0.
 */
static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

	while (len > 0) {
		unsigned int l;

		/*
		 * Bulk path: taken when the buffered count is 0 or a full
		 * block and the total exceeds one block.
		 */
		if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
		    (ctx->len + len) > AES_BLOCK_SIZE) {

			int blocks = len / AES_BLOCK_SIZE;

			len %= AES_BLOCK_SIZE;

			/*
			 * Encrypt the pending full block first if there is
			 * one; encrypt after the last input block only if
			 * trailing bytes follow it.
			 */
			mac_do_update(&tctx->key, p, blocks, ctx->dg,
				      (ctx->len != 0), (len != 0));

			p += blocks * AES_BLOCK_SIZE;

			if (!len) {
				/* last block stays pending, unencrypted */
				ctx->len = AES_BLOCK_SIZE;
				break;
			}
			ctx->len = 0;
		}

		/* Partial path: XOR in as many bytes as still fit the block */
		l = min(len, AES_BLOCK_SIZE - ctx->len);

		if (l <= AES_BLOCK_SIZE) {
			crypto_xor(ctx->dg + ctx->len, p, l);
			ctx->len += l;
			len -= l;
			p += l;
		}
	}

	return 0;
}
701 
702 static int cbcmac_final(struct shash_desc *desc, u8 *out)
703 {
704 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
705 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
706 
707 	mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
708 
709 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
710 
711 	return 0;
712 }
713 
714 static int cmac_final(struct shash_desc *desc, u8 *out)
715 {
716 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
717 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
718 	u8 *consts = tctx->consts;
719 
720 	if (ctx->len != AES_BLOCK_SIZE) {
721 		ctx->dg[ctx->len] ^= 0x80;
722 		consts += AES_BLOCK_SIZE;
723 	}
724 
725 	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
726 
727 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
728 
729 	return 0;
730 }
731 
/*
 * shash (MAC) algorithms registered by this module. All three share
 * mac_init() and mac_update(); they differ in their setkey and final hooks
 * and in how much room they reserve for constants after struct mac_tfm_ctx.
 */
static struct shash_alg mac_algs[] = { {
	/* CMAC: two derived constants stored after the key */
	.base.cra_name		= "cmac(aes)",
	.base.cra_driver_name	= "cmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= cmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	/* XCBC: same finalisation as CMAC, different key derivation */
	.base.cra_name		= "xcbc(aes)",
	.base.cra_driver_name	= "xcbc-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= xcbc_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	/* plain CBC-MAC: block size 1, no constants area */
	.base.cra_name		= "cbcmac(aes)",
	.base.cra_driver_name	= "cbcmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cbcmac_final,
	.setkey			= cbcmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
} };
777 
778 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
779 
780 static void aes_exit(void)
781 {
782 	int i;
783 
784 	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
785 		if (aes_simd_algs[i])
786 			simd_skcipher_free(aes_simd_algs[i]);
787 
788 	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
789 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
790 }
791 
792 static int __init aes_init(void)
793 {
794 	struct simd_skcipher_alg *simd;
795 	const char *basename;
796 	const char *algname;
797 	const char *drvname;
798 	int err;
799 	int i;
800 
801 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
802 	if (err)
803 		return err;
804 
805 	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
806 	if (err)
807 		goto unregister_ciphers;
808 
809 	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
810 		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
811 			continue;
812 
813 		algname = aes_algs[i].base.cra_name + 2;
814 		drvname = aes_algs[i].base.cra_driver_name + 2;
815 		basename = aes_algs[i].base.cra_driver_name;
816 		simd = simd_skcipher_create_compat(algname, drvname, basename);
817 		err = PTR_ERR(simd);
818 		if (IS_ERR(simd))
819 			goto unregister_simds;
820 
821 		aes_simd_algs[i] = simd;
822 	}
823 
824 	return 0;
825 
826 unregister_simds:
827 	aes_exit();
828 	return err;
829 unregister_ciphers:
830 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
831 	return err;
832 }
833 
/*
 * Module entry points. The Crypto Extensions flavour registers aes_init
 * through module_cpu_feature_match(AES, ...); the NEON flavour uses a plain
 * module_init() and additionally exports two of its assembly routines.
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
#endif
module_exit(aes_exit);
842