/* xref: /openbmc/linux/arch/arm64/crypto/aes-glue.c (revision 8730046c) */
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */
10 
11 #include <asm/neon.h>
12 #include <asm/hwcap.h>
13 #include <crypto/aes.h>
14 #include <crypto/internal/simd.h>
15 #include <crypto/internal/skcipher.h>
16 #include <linux/module.h>
17 #include <linux/cpufeature.h>
18 #include <crypto/xts.h>
19 
20 #include "aes-ce-setkey.h"
21 
22 #ifdef USE_V8_CRYPTO_EXTENSIONS
23 #define MODE			"ce"
24 #define PRIO			300
25 #define aes_setkey		ce_aes_setkey
26 #define aes_expandkey		ce_aes_expandkey
27 #define aes_ecb_encrypt		ce_aes_ecb_encrypt
28 #define aes_ecb_decrypt		ce_aes_ecb_decrypt
29 #define aes_cbc_encrypt		ce_aes_cbc_encrypt
30 #define aes_cbc_decrypt		ce_aes_cbc_decrypt
31 #define aes_ctr_encrypt		ce_aes_ctr_encrypt
32 #define aes_xts_encrypt		ce_aes_xts_encrypt
33 #define aes_xts_decrypt		ce_aes_xts_decrypt
34 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
35 #else
36 #define MODE			"neon"
37 #define PRIO			200
38 #define aes_setkey		crypto_aes_set_key
39 #define aes_expandkey		crypto_aes_expand_key
40 #define aes_ecb_encrypt		neon_aes_ecb_encrypt
41 #define aes_ecb_decrypt		neon_aes_ecb_decrypt
42 #define aes_cbc_encrypt		neon_aes_cbc_encrypt
43 #define aes_cbc_decrypt		neon_aes_cbc_decrypt
44 #define aes_ctr_encrypt		neon_aes_ctr_encrypt
45 #define aes_xts_encrypt		neon_aes_xts_encrypt
46 #define aes_xts_decrypt		neon_aes_xts_decrypt
47 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
48 MODULE_ALIAS_CRYPTO("ecb(aes)");
49 MODULE_ALIAS_CRYPTO("cbc(aes)");
50 MODULE_ALIAS_CRYPTO("ctr(aes)");
51 MODULE_ALIAS_CRYPTO("xts(aes)");
52 #endif
53 
54 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
55 MODULE_LICENSE("GPL v2");
56 
57 /* defined in aes-modes.S */
58 asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
59 				int rounds, int blocks, int first);
60 asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
61 				int rounds, int blocks, int first);
62 
63 asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
64 				int rounds, int blocks, u8 iv[], int first);
65 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
66 				int rounds, int blocks, u8 iv[], int first);
67 
68 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
69 				int rounds, int blocks, u8 ctr[], int first);
70 
71 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
72 				int rounds, int blocks, u8 const rk2[], u8 iv[],
73 				int first);
74 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
75 				int rounds, int blocks, u8 const rk2[], u8 iv[],
76 				int first);
77 
78 struct crypto_aes_xts_ctx {
79 	struct crypto_aes_ctx key1;
80 	struct crypto_aes_ctx __aligned(8) key2;
81 };
82 
83 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
84 			       unsigned int key_len)
85 {
86 	return aes_setkey(crypto_skcipher_tfm(tfm), in_key, key_len);
87 }
88 
89 static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
90 		       unsigned int key_len)
91 {
92 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
93 	int ret;
94 
95 	ret = xts_verify_key(tfm, in_key, key_len);
96 	if (ret)
97 		return ret;
98 
99 	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
100 	if (!ret)
101 		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
102 				    key_len / 2);
103 	if (!ret)
104 		return 0;
105 
106 	crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
107 	return -EINVAL;
108 }
109 
110 static int ecb_encrypt(struct skcipher_request *req)
111 {
112 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
113 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
114 	int err, first, rounds = 6 + ctx->key_length / 4;
115 	struct skcipher_walk walk;
116 	unsigned int blocks;
117 
118 	err = skcipher_walk_virt(&walk, req, true);
119 
120 	kernel_neon_begin();
121 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
122 		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
123 				(u8 *)ctx->key_enc, rounds, blocks, first);
124 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
125 	}
126 	kernel_neon_end();
127 	return err;
128 }
129 
130 static int ecb_decrypt(struct skcipher_request *req)
131 {
132 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
133 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
134 	int err, first, rounds = 6 + ctx->key_length / 4;
135 	struct skcipher_walk walk;
136 	unsigned int blocks;
137 
138 	err = skcipher_walk_virt(&walk, req, true);
139 
140 	kernel_neon_begin();
141 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
142 		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
143 				(u8 *)ctx->key_dec, rounds, blocks, first);
144 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
145 	}
146 	kernel_neon_end();
147 	return err;
148 }
149 
150 static int cbc_encrypt(struct skcipher_request *req)
151 {
152 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
153 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
154 	int err, first, rounds = 6 + ctx->key_length / 4;
155 	struct skcipher_walk walk;
156 	unsigned int blocks;
157 
158 	err = skcipher_walk_virt(&walk, req, true);
159 
160 	kernel_neon_begin();
161 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
162 		aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
163 				(u8 *)ctx->key_enc, rounds, blocks, walk.iv,
164 				first);
165 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
166 	}
167 	kernel_neon_end();
168 	return err;
169 }
170 
171 static int cbc_decrypt(struct skcipher_request *req)
172 {
173 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
174 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
175 	int err, first, rounds = 6 + ctx->key_length / 4;
176 	struct skcipher_walk walk;
177 	unsigned int blocks;
178 
179 	err = skcipher_walk_virt(&walk, req, true);
180 
181 	kernel_neon_begin();
182 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
183 		aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
184 				(u8 *)ctx->key_dec, rounds, blocks, walk.iv,
185 				first);
186 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
187 	}
188 	kernel_neon_end();
189 	return err;
190 }
191 
192 static int ctr_encrypt(struct skcipher_request *req)
193 {
194 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
195 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
196 	int err, first, rounds = 6 + ctx->key_length / 4;
197 	struct skcipher_walk walk;
198 	int blocks;
199 
200 	err = skcipher_walk_virt(&walk, req, true);
201 
202 	first = 1;
203 	kernel_neon_begin();
204 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
205 		aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
206 				(u8 *)ctx->key_enc, rounds, blocks, walk.iv,
207 				first);
208 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
209 		first = 0;
210 	}
211 	if (walk.nbytes) {
212 		u8 __aligned(8) tail[AES_BLOCK_SIZE];
213 		unsigned int nbytes = walk.nbytes;
214 		u8 *tdst = walk.dst.virt.addr;
215 		u8 *tsrc = walk.src.virt.addr;
216 
217 		/*
218 		 * Minimum alignment is 8 bytes, so if nbytes is <= 8, we need
219 		 * to tell aes_ctr_encrypt() to only read half a block.
220 		 */
221 		blocks = (nbytes <= 8) ? -1 : 1;
222 
223 		aes_ctr_encrypt(tail, tsrc, (u8 *)ctx->key_enc, rounds,
224 				blocks, walk.iv, first);
225 		memcpy(tdst, tail, nbytes);
226 		err = skcipher_walk_done(&walk, 0);
227 	}
228 	kernel_neon_end();
229 
230 	return err;
231 }
232 
233 static int xts_encrypt(struct skcipher_request *req)
234 {
235 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
236 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
237 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
238 	struct skcipher_walk walk;
239 	unsigned int blocks;
240 
241 	err = skcipher_walk_virt(&walk, req, true);
242 
243 	kernel_neon_begin();
244 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
245 		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
246 				(u8 *)ctx->key1.key_enc, rounds, blocks,
247 				(u8 *)ctx->key2.key_enc, walk.iv, first);
248 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
249 	}
250 	kernel_neon_end();
251 
252 	return err;
253 }
254 
255 static int xts_decrypt(struct skcipher_request *req)
256 {
257 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
258 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
259 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
260 	struct skcipher_walk walk;
261 	unsigned int blocks;
262 
263 	err = skcipher_walk_virt(&walk, req, true);
264 
265 	kernel_neon_begin();
266 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
267 		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
268 				(u8 *)ctx->key1.key_dec, rounds, blocks,
269 				(u8 *)ctx->key2.key_enc, walk.iv, first);
270 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
271 	}
272 	kernel_neon_end();
273 
274 	return err;
275 }
276 
277 static struct skcipher_alg aes_algs[] = { {
278 	.base = {
279 		.cra_name		= "__ecb(aes)",
280 		.cra_driver_name	= "__ecb-aes-" MODE,
281 		.cra_priority		= PRIO,
282 		.cra_flags		= CRYPTO_ALG_INTERNAL,
283 		.cra_blocksize		= AES_BLOCK_SIZE,
284 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
285 		.cra_alignmask		= 7,
286 		.cra_module		= THIS_MODULE,
287 	},
288 	.min_keysize	= AES_MIN_KEY_SIZE,
289 	.max_keysize	= AES_MAX_KEY_SIZE,
290 	.setkey		= skcipher_aes_setkey,
291 	.encrypt	= ecb_encrypt,
292 	.decrypt	= ecb_decrypt,
293 }, {
294 	.base = {
295 		.cra_name		= "__cbc(aes)",
296 		.cra_driver_name	= "__cbc-aes-" MODE,
297 		.cra_priority		= PRIO,
298 		.cra_flags		= CRYPTO_ALG_INTERNAL,
299 		.cra_blocksize		= AES_BLOCK_SIZE,
300 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
301 		.cra_alignmask		= 7,
302 		.cra_module		= THIS_MODULE,
303 	},
304 	.min_keysize	= AES_MIN_KEY_SIZE,
305 	.max_keysize	= AES_MAX_KEY_SIZE,
306 	.ivsize		= AES_BLOCK_SIZE,
307 	.setkey		= skcipher_aes_setkey,
308 	.encrypt	= cbc_encrypt,
309 	.decrypt	= cbc_decrypt,
310 }, {
311 	.base = {
312 		.cra_name		= "__ctr(aes)",
313 		.cra_driver_name	= "__ctr-aes-" MODE,
314 		.cra_priority		= PRIO,
315 		.cra_flags		= CRYPTO_ALG_INTERNAL,
316 		.cra_blocksize		= 1,
317 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
318 		.cra_alignmask		= 7,
319 		.cra_module		= THIS_MODULE,
320 	},
321 	.min_keysize	= AES_MIN_KEY_SIZE,
322 	.max_keysize	= AES_MAX_KEY_SIZE,
323 	.ivsize		= AES_BLOCK_SIZE,
324 	.chunksize	= AES_BLOCK_SIZE,
325 	.setkey		= skcipher_aes_setkey,
326 	.encrypt	= ctr_encrypt,
327 	.decrypt	= ctr_encrypt,
328 }, {
329 	.base = {
330 		.cra_name		= "__xts(aes)",
331 		.cra_driver_name	= "__xts-aes-" MODE,
332 		.cra_priority		= PRIO,
333 		.cra_flags		= CRYPTO_ALG_INTERNAL,
334 		.cra_blocksize		= AES_BLOCK_SIZE,
335 		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
336 		.cra_alignmask		= 7,
337 		.cra_module		= THIS_MODULE,
338 	},
339 	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
340 	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
341 	.ivsize		= AES_BLOCK_SIZE,
342 	.setkey		= xts_set_key,
343 	.encrypt	= xts_encrypt,
344 	.decrypt	= xts_decrypt,
345 } };
346 
347 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
348 
349 static void aes_exit(void)
350 {
351 	int i;
352 
353 	for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
354 		simd_skcipher_free(aes_simd_algs[i]);
355 
356 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
357 }
358 
359 static int __init aes_init(void)
360 {
361 	struct simd_skcipher_alg *simd;
362 	const char *basename;
363 	const char *algname;
364 	const char *drvname;
365 	int err;
366 	int i;
367 
368 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
369 	if (err)
370 		return err;
371 
372 	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
373 		algname = aes_algs[i].base.cra_name + 2;
374 		drvname = aes_algs[i].base.cra_driver_name + 2;
375 		basename = aes_algs[i].base.cra_driver_name;
376 		simd = simd_skcipher_create_compat(algname, drvname, basename);
377 		err = PTR_ERR(simd);
378 		if (IS_ERR(simd))
379 			goto unregister_simds;
380 
381 		aes_simd_algs[i] = simd;
382 	}
383 
384 	return 0;
385 
386 unregister_simds:
387 	aes_exit();
388 	return err;
389 }
390 
391 #ifdef USE_V8_CRYPTO_EXTENSIONS
392 module_cpu_feature_match(AES, aes_init);
393 #else
394 module_init(aes_init);
395 #endif
396 module_exit(aes_exit);
397