/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/cbc.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

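/*
 * Prototypes for the bit-sliced primitives implemented in the accompanying
 * NEON assembly (aes-neonbs-core.S). The bit-sliced representation operates
 * on batches of blocks (eight per call at most, matching .walksize below)
 * rather than on a single block at a time.
 */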
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void __aes_arm_encrypt(const u32 rk[], int rounds, const u8 in[],
				  u8 out[]);

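/*
 * Transform context: rk[] holds the expanded key in the bit-sliced format
 * produced by aesbs_convert_key(), hence the large block-aligned buffer
 * rather than a struct crypto_aes_ctx.
 */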
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;	/* 10/12/14 rounds for AES-128/192/256 */

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

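/*
 * Common ECB walk helper: process as many full blocks as the walk hands us,
 * rounding down to a multiple of the walk stride (eight blocks) on all but
 * the final chunk so the assembly always sees complete bit-sliced batches.
 */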
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

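/*
 * CBC encryption is inherently sequential and cannot benefit from the
 * eight-way bit-sliced code, so it encrypts one block at a time with the
 * scalar ARM asm cipher, using the ordinary key schedule kept in ctx->enc.
 * CBC decryption is parallelizable and goes through the NEON routine.
 */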
static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	__aes_arm_encrypt(ctx->enc, ctx->key.rounds, src, dst);
}

static int cbc_encrypt(struct skcipher_request *req)
{
	return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

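/*
 * CTR mode: full blocks go through the bit-sliced routine. If the request
 * ends in a partial block, 'final' points at a stack buffer that receives
 * the last keystream block, and the tail bytes are XORed in here in C.
 */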
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			if (dst != src)
				memcpy(dst, src, walk.total % AES_BLOCK_SIZE);
			crypto_xor(dst, final, walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

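/*
 * XTS: the supplied key is twice the AES key size. The second half is
 * expanded into ctx->twkey and later used to encrypt the IV into the
 * initial tweak; the first half becomes the bit-sliced data key via
 * aesbs_setkey().
 */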
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

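/*
 * Common XTS walk: encrypt the IV with the tweak key once up front to obtain
 * the first tweak, then pass the running tweak in walk.iv to the assembly,
 * which advances it as it processes each batch of blocks.
 */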
static int __xts_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	__aes_arm_encrypt(ctx->twkey, ctx->key.rounds, walk.iv, walk.iv);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_decrypt);
}

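/*
 * The algorithms are registered as internal ciphers ("__" prefix,
 * CRYPTO_ALG_INTERNAL) because they may only run in contexts where NEON use
 * is permitted. aes_init() wraps each of them in a simd skcipher that
 * defers to an asynchronous helper when SIMD is not usable by the caller.
 */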
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

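/*
 * Module init: bail out if the CPU lacks NEON, register the internal
 * skciphers, then create a simd wrapper for each one. The wrapper names are
 * derived from the internal names with the "__" prefix stripped.
 */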
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

module_init(aes_init);
module_exit(aes_exit);