// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)-all");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

MODULE_IMPORT_NS(CRYPTO_INTERNAL);

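/*
 * Prototypes of the bit-sliced AES routines implemented in NEON
 * assembler. They operate on whole AES blocks, using the bit-sliced
 * key schedule produced by aesbs_convert_key().
 */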
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

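/*
 * The round keys are kept in the bit-sliced representation expected by
 * the assembler code; the buffer is sized for the largest (AES-256)
 * key schedule.
 */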
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_skcipher	*enc_tfm;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*cts_tfm;
	struct crypto_cipher	*tweak_tfm;
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

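/*
 * Expand the key with the generic AES library code, then convert the
 * encryption round keys into the bit-sliced format. The conversion uses
 * the NEON register file, so it must run inside a
 * kernel_neon_begin()/kernel_neon_end() section.
 */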
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

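/*
 * Shared ECB helper: walk the request and hand whole blocks to the NEON
 * routine. Every chunk except the final one is rounded down to a
 * multiple of the walk stride (8 blocks).
 */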
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

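/*
 * CBC: the bit-sliced key is only used for decryption. Encryption is
 * inherently sequential and is delegated to a separate cbc(aes)
 * implementation, which is keyed here as well.
 */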
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
}

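/* CBC encryption cannot be parallelised, so forward it to the fallback tfm. */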
static int cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_request *subreq = skcipher_request_ctx(req);
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	skcipher_request_set_tfm(subreq, ctx->enc_tfm);
	skcipher_request_set_callback(subreq,
				      skcipher_request_flags(req),
				      NULL, NULL);
	skcipher_request_set_crypt(subreq, req->src, req->dst,
				   req->cryptlen, req->iv);

	return crypto_skcipher_encrypt(subreq);
}

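/* CBC decryption is parallelisable and is handled by the NEON code in batches. */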
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

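/*
 * Allocate the cbc(aes) skcipher used on the encryption path, and
 * reserve room in the request context for the subrequest handed to it.
 */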
static int cbc_init(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int reqsize;

	ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
					     CRYPTO_ALG_NEED_FALLBACK);
	if (IS_ERR(ctx->enc_tfm))
		return PTR_ERR(ctx->enc_tfm);

	reqsize = sizeof(struct skcipher_request);
	reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
	crypto_skcipher_set_reqsize(tfm, reqsize);

	return 0;
}

static void cbc_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_skcipher(ctx->enc_tfm);
}

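/*
 * The synchronous ctr(aes) variant also keeps the generic expanded key,
 * so that single blocks can be encrypted with the scalar AES library
 * code when the NEON unit may not be used.
 */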
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

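/*
 * CTR: whole blocks are handled by the NEON code. If the request ends
 * in a partial block, the assembler routine returns the keystream for
 * it via 'final', and the tail is XORed into place here.
 */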
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

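/*
 * Encrypt a single counter block with the scalar AES library code; used
 * as the fallback when the NEON unit is not usable in the current
 * context.
 */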
static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts so that the cachelines used by
	 * the scalar AES code are not evicted while it runs, which would
	 * otherwise widen its cache timing signal.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}

static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

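/*
 * XTS: the second half of the key drives the tweak cipher; the first
 * half keys both the bit-sliced code and the plain AES cipher used for
 * the ciphertext stealing tail.
 */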
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
	if (err)
		return err;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->cts_tfm))
		return PTR_ERR(ctx->cts_tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		crypto_free_cipher(ctx->cts_tfm);

	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
	crypto_free_cipher(ctx->cts_tfm);
}

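/*
 * Encrypt or decrypt all complete blocks with the NEON code. If the
 * request length is not a multiple of the block size, the remainder is
 * handled below using ciphertext stealing, with the single block cipher
 * allocated in xts_init() and the tweak value carried in req->iv.
 */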
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

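/*
 * The "__" prefixed algorithms are internal: they are not exposed to
 * users directly, but are wrapped by the crypto_simd helpers created in
 * aes_init(), which defer to cryptd whenever the NEON unit cannot be
 * used. The synchronous ctr(aes) variant carries its own scalar
 * fallback and is registered as-is.
 */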
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL |
				  CRYPTO_ALG_NEED_FALLBACK,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
	.init			= cbc_init,
	.exit			= cbc_exit,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
	.init			= xts_init,
	.exit			= xts_exit,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

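/*
 * Register all algorithms, then create a simd wrapper (without the "__"
 * prefix) for each internal one so that it can be used from any
 * context.
 */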
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);