// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/cbc.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

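/*
 * These routines are implemented in NEON assembly and operate on the
 * bit-sliced representation, processing up to eight AES blocks at a time.
 */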
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

struct aesbs_ctx {
	int	rounds;
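	/* bit-sliced round key schedule, generated by aesbs_convert_key() */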
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*enc_tfm;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*cts_tfm;
	struct crypto_cipher	*tweak_tfm;
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

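		/*
		 * If more data follows, only process a multiple of the
		 * eight-block walk stride here; the remainder is handled
		 * in a later iteration.
		 */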
		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return crypto_cipher_setkey(ctx->enc_tfm, in_key, key_len);
}

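/*
 * CBC encryption is inherently sequential, so it is done one block at a
 * time using a generic AES cipher; only CBC decryption can use the
 * parallel bit-sliced NEON code.
 */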
static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_cipher_encrypt_one(ctx->enc_tfm, dst, src);
}

static int cbc_encrypt(struct skcipher_request *req)
{
	return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int cbc_init(struct crypto_tfm *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

	ctx->enc_tfm = crypto_alloc_cipher("aes", 0, 0);

	return PTR_ERR_OR_ZERO(ctx->enc_tfm);
}

static void cbc_exit(struct crypto_tfm *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

	crypto_free_cipher(ctx->enc_tfm);
}

static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

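		/*
		 * The asm code has placed the keystream for the trailing
		 * partial block in 'buf'; XOR it into the remaining bytes.
		 */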
		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}

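/*
 * Synchronous CTR entry point: fall back to the scalar AES library code
 * when the NEON unit cannot be used in the current context.
 */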
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

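	/*
	 * The XTS key is two AES keys: the first half keys the data (and
	 * ciphertext stealing) cipher, the second half keys the tweak cipher.
	 */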
	key_len /= 2;
	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
	if (err)
		return err;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_tfm *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->cts_tfm))
		return PTR_ERR(ctx->cts_tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		crypto_free_cipher(ctx->cts_tfm);

	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_tfm *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
	crypto_free_cipher(ctx->cts_tfm);
}

static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

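	/*
	 * If the length is not a multiple of the block size, process the
	 * full blocks here and handle the leftover bytes below using
	 * ciphertext stealing.
	 */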
	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
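		/*
		 * When decrypting with ciphertext stealing, the final two
		 * tweaks are consumed in reverse order; the asm code swaps
		 * them when this flag is set.
		 */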
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
	.base.cra_init		= cbc_init,
	.base.cra_exit		= cbc_exit,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
	.base.cra_init		= xts_init,
	.base.cra_exit		= xts_exit,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

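	/*
	 * Wrap each internal ("__" prefixed) algorithm in a SIMD helper,
	 * which defers to an asynchronous cryptd queue whenever the NEON
	 * unit cannot be used directly.
	 */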
	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);