xref: /openbmc/linux/arch/arm/crypto/aes-ce-glue.c (revision 22d55f02)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * aes-ce-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7 
8 #include <asm/hwcap.h>
9 #include <asm/neon.h>
10 #include <crypto/aes.h>
11 #include <crypto/internal/simd.h>
12 #include <crypto/internal/skcipher.h>
13 #include <linux/cpufeature.h>
14 #include <linux/module.h>
15 #include <crypto/xts.h>
16 
17 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
18 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
19 MODULE_LICENSE("GPL v2");
20 
21 /* defined in aes-ce-core.S */
22 asmlinkage u32 ce_aes_sub(u32 input);
23 asmlinkage void ce_aes_invert(void *dst, void *src);
24 
25 asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
26 				   int rounds, int blocks);
27 asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
28 				   int rounds, int blocks);
29 
30 asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
31 				   int rounds, int blocks, u8 iv[]);
32 asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
33 				   int rounds, int blocks, u8 iv[]);
34 
35 asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
36 				   int rounds, int blocks, u8 ctr[]);
37 
38 asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
39 				   int rounds, int blocks, u8 iv[],
40 				   u8 const rk2[], int first);
41 asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
42 				   int rounds, int blocks, u8 iv[],
43 				   u8 const rk2[], int first);
44 
45 struct aes_block {
46 	u8 b[AES_BLOCK_SIZE];
47 };
48 
49 static int num_rounds(struct crypto_aes_ctx *ctx)
50 {
51 	/*
52 	 * # of rounds specified by AES:
53 	 * 128 bit key		10 rounds
54 	 * 192 bit key		12 rounds
55 	 * 256 bit key		14 rounds
56 	 * => n byte key	=> 6 + (n/4) rounds
57 	 */
58 	return 6 + ctx->key_length / 4;
59 }
60 
61 static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
62 			    unsigned int key_len)
63 {
64 	/*
65 	 * The AES key schedule round constants
66 	 */
67 	static u8 const rcon[] = {
68 		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
69 	};
70 
71 	u32 kwords = key_len / sizeof(u32);
72 	struct aes_block *key_enc, *key_dec;
73 	int i, j;
74 
75 	if (key_len != AES_KEYSIZE_128 &&
76 	    key_len != AES_KEYSIZE_192 &&
77 	    key_len != AES_KEYSIZE_256)
78 		return -EINVAL;
79 
80 	memcpy(ctx->key_enc, in_key, key_len);
81 	ctx->key_length = key_len;
82 
83 	kernel_neon_begin();
84 	for (i = 0; i < sizeof(rcon); i++) {
85 		u32 *rki = ctx->key_enc + (i * kwords);
86 		u32 *rko = rki + kwords;
87 
88 #ifndef CONFIG_CPU_BIG_ENDIAN
89 		rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
90 		rko[0] = rko[0] ^ rki[0] ^ rcon[i];
91 #else
92 		rko[0] = rol32(ce_aes_sub(rki[kwords - 1]), 8);
93 		rko[0] = rko[0] ^ rki[0] ^ (rcon[i] << 24);
94 #endif
95 		rko[1] = rko[0] ^ rki[1];
96 		rko[2] = rko[1] ^ rki[2];
97 		rko[3] = rko[2] ^ rki[3];
98 
99 		if (key_len == AES_KEYSIZE_192) {
100 			if (i >= 7)
101 				break;
102 			rko[4] = rko[3] ^ rki[4];
103 			rko[5] = rko[4] ^ rki[5];
104 		} else if (key_len == AES_KEYSIZE_256) {
105 			if (i >= 6)
106 				break;
107 			rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
108 			rko[5] = rko[4] ^ rki[5];
109 			rko[6] = rko[5] ^ rki[6];
110 			rko[7] = rko[6] ^ rki[7];
111 		}
112 	}
113 
114 	/*
115 	 * Generate the decryption keys for the Equivalent Inverse Cipher.
116 	 * This involves reversing the order of the round keys, and applying
117 	 * the Inverse Mix Columns transformation on all but the first and
118 	 * the last one.
119 	 */
120 	key_enc = (struct aes_block *)ctx->key_enc;
121 	key_dec = (struct aes_block *)ctx->key_dec;
122 	j = num_rounds(ctx);
123 
124 	key_dec[0] = key_enc[j];
125 	for (i = 1, j--; j > 0; i++, j--)
126 		ce_aes_invert(key_dec + i, key_enc + j);
127 	key_dec[i] = key_enc[0];
128 
129 	kernel_neon_end();
130 	return 0;
131 }
132 
133 static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
134 			 unsigned int key_len)
135 {
136 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
137 	int ret;
138 
139 	ret = ce_aes_expandkey(ctx, in_key, key_len);
140 	if (!ret)
141 		return 0;
142 
143 	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
144 	return -EINVAL;
145 }
146 
147 struct crypto_aes_xts_ctx {
148 	struct crypto_aes_ctx key1;
149 	struct crypto_aes_ctx __aligned(8) key2;
150 };
151 
152 static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
153 		       unsigned int key_len)
154 {
155 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
156 	int ret;
157 
158 	ret = xts_verify_key(tfm, in_key, key_len);
159 	if (ret)
160 		return ret;
161 
162 	ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
163 	if (!ret)
164 		ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
165 				       key_len / 2);
166 	if (!ret)
167 		return 0;
168 
169 	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
170 	return -EINVAL;
171 }
172 
173 static int ecb_encrypt(struct skcipher_request *req)
174 {
175 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
176 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
177 	struct skcipher_walk walk;
178 	unsigned int blocks;
179 	int err;
180 
181 	err = skcipher_walk_virt(&walk, req, true);
182 
183 	kernel_neon_begin();
184 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
185 		ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
186 				   (u8 *)ctx->key_enc, num_rounds(ctx), blocks);
187 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
188 	}
189 	kernel_neon_end();
190 	return err;
191 }
192 
193 static int ecb_decrypt(struct skcipher_request *req)
194 {
195 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
196 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
197 	struct skcipher_walk walk;
198 	unsigned int blocks;
199 	int err;
200 
201 	err = skcipher_walk_virt(&walk, req, true);
202 
203 	kernel_neon_begin();
204 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
205 		ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
206 				   (u8 *)ctx->key_dec, num_rounds(ctx), blocks);
207 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
208 	}
209 	kernel_neon_end();
210 	return err;
211 }
212 
213 static int cbc_encrypt(struct skcipher_request *req)
214 {
215 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
216 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
217 	struct skcipher_walk walk;
218 	unsigned int blocks;
219 	int err;
220 
221 	err = skcipher_walk_virt(&walk, req, true);
222 
223 	kernel_neon_begin();
224 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
225 		ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
226 				   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
227 				   walk.iv);
228 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
229 	}
230 	kernel_neon_end();
231 	return err;
232 }
233 
234 static int cbc_decrypt(struct skcipher_request *req)
235 {
236 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
237 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
238 	struct skcipher_walk walk;
239 	unsigned int blocks;
240 	int err;
241 
242 	err = skcipher_walk_virt(&walk, req, true);
243 
244 	kernel_neon_begin();
245 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
246 		ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
247 				   (u8 *)ctx->key_dec, num_rounds(ctx), blocks,
248 				   walk.iv);
249 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
250 	}
251 	kernel_neon_end();
252 	return err;
253 }
254 
255 static int ctr_encrypt(struct skcipher_request *req)
256 {
257 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
258 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
259 	struct skcipher_walk walk;
260 	int err, blocks;
261 
262 	err = skcipher_walk_virt(&walk, req, true);
263 
264 	kernel_neon_begin();
265 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
266 		ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
267 				   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
268 				   walk.iv);
269 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
270 	}
271 	if (walk.nbytes) {
272 		u8 __aligned(8) tail[AES_BLOCK_SIZE];
273 		unsigned int nbytes = walk.nbytes;
274 		u8 *tdst = walk.dst.virt.addr;
275 		u8 *tsrc = walk.src.virt.addr;
276 
277 		/*
278 		 * Tell aes_ctr_encrypt() to process a tail block.
279 		 */
280 		blocks = -1;
281 
282 		ce_aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc,
283 				   num_rounds(ctx), blocks, walk.iv);
284 		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
285 		err = skcipher_walk_done(&walk, 0);
286 	}
287 	kernel_neon_end();
288 
289 	return err;
290 }
291 
292 static int xts_encrypt(struct skcipher_request *req)
293 {
294 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
295 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
296 	int err, first, rounds = num_rounds(&ctx->key1);
297 	struct skcipher_walk walk;
298 	unsigned int blocks;
299 
300 	err = skcipher_walk_virt(&walk, req, true);
301 
302 	kernel_neon_begin();
303 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
304 		ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
305 				   (u8 *)ctx->key1.key_enc, rounds, blocks,
306 				   walk.iv, (u8 *)ctx->key2.key_enc, first);
307 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
308 	}
309 	kernel_neon_end();
310 
311 	return err;
312 }
313 
314 static int xts_decrypt(struct skcipher_request *req)
315 {
316 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
317 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
318 	int err, first, rounds = num_rounds(&ctx->key1);
319 	struct skcipher_walk walk;
320 	unsigned int blocks;
321 
322 	err = skcipher_walk_virt(&walk, req, true);
323 
324 	kernel_neon_begin();
325 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
326 		ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
327 				   (u8 *)ctx->key1.key_dec, rounds, blocks,
328 				   walk.iv, (u8 *)ctx->key2.key_enc, first);
329 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
330 	}
331 	kernel_neon_end();
332 
333 	return err;
334 }
335 
336 static struct skcipher_alg aes_algs[] = { {
337 	.base = {
338 		.cra_name		= "__ecb(aes)",
339 		.cra_driver_name	= "__ecb-aes-ce",
340 		.cra_priority		= 300,
341 		.cra_flags		= CRYPTO_ALG_INTERNAL,
342 		.cra_blocksize		= AES_BLOCK_SIZE,
343 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
344 		.cra_module		= THIS_MODULE,
345 	},
346 	.min_keysize	= AES_MIN_KEY_SIZE,
347 	.max_keysize	= AES_MAX_KEY_SIZE,
348 	.setkey		= ce_aes_setkey,
349 	.encrypt	= ecb_encrypt,
350 	.decrypt	= ecb_decrypt,
351 }, {
352 	.base = {
353 		.cra_name		= "__cbc(aes)",
354 		.cra_driver_name	= "__cbc-aes-ce",
355 		.cra_priority		= 300,
356 		.cra_flags		= CRYPTO_ALG_INTERNAL,
357 		.cra_blocksize		= AES_BLOCK_SIZE,
358 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
359 		.cra_module		= THIS_MODULE,
360 	},
361 	.min_keysize	= AES_MIN_KEY_SIZE,
362 	.max_keysize	= AES_MAX_KEY_SIZE,
363 	.ivsize		= AES_BLOCK_SIZE,
364 	.setkey		= ce_aes_setkey,
365 	.encrypt	= cbc_encrypt,
366 	.decrypt	= cbc_decrypt,
367 }, {
368 	.base = {
369 		.cra_name		= "__ctr(aes)",
370 		.cra_driver_name	= "__ctr-aes-ce",
371 		.cra_priority		= 300,
372 		.cra_flags		= CRYPTO_ALG_INTERNAL,
373 		.cra_blocksize		= 1,
374 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
375 		.cra_module		= THIS_MODULE,
376 	},
377 	.min_keysize	= AES_MIN_KEY_SIZE,
378 	.max_keysize	= AES_MAX_KEY_SIZE,
379 	.ivsize		= AES_BLOCK_SIZE,
380 	.chunksize	= AES_BLOCK_SIZE,
381 	.setkey		= ce_aes_setkey,
382 	.encrypt	= ctr_encrypt,
383 	.decrypt	= ctr_encrypt,
384 }, {
385 	.base = {
386 		.cra_name		= "__xts(aes)",
387 		.cra_driver_name	= "__xts-aes-ce",
388 		.cra_priority		= 300,
389 		.cra_flags		= CRYPTO_ALG_INTERNAL,
390 		.cra_blocksize		= AES_BLOCK_SIZE,
391 		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
392 		.cra_module		= THIS_MODULE,
393 	},
394 	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
395 	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
396 	.ivsize		= AES_BLOCK_SIZE,
397 	.setkey		= xts_set_key,
398 	.encrypt	= xts_encrypt,
399 	.decrypt	= xts_decrypt,
400 } };
401 
402 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
403 
404 static void aes_exit(void)
405 {
406 	int i;
407 
408 	for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
409 		simd_skcipher_free(aes_simd_algs[i]);
410 
411 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
412 }
413 
414 static int __init aes_init(void)
415 {
416 	struct simd_skcipher_alg *simd;
417 	const char *basename;
418 	const char *algname;
419 	const char *drvname;
420 	int err;
421 	int i;
422 
423 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
424 	if (err)
425 		return err;
426 
427 	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
428 		algname = aes_algs[i].base.cra_name + 2;
429 		drvname = aes_algs[i].base.cra_driver_name + 2;
430 		basename = aes_algs[i].base.cra_driver_name;
431 		simd = simd_skcipher_create_compat(algname, drvname, basename);
432 		err = PTR_ERR(simd);
433 		if (IS_ERR(simd))
434 			goto unregister_simds;
435 
436 		aes_simd_algs[i] = simd;
437 	}
438 
439 	return 0;
440 
441 unregister_simds:
442 	aes_exit();
443 	return err;
444 }
445 
446 module_cpu_feature_match(AES, aes_init);
447 module_exit(aes_exit);
448