/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

#include "aes-ctr-fallback.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
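
/*
 * NOTE: __ecb_crypt() above, and the CBC/CTR/XTS routines below, all follow
 * the same pattern: each chunk handed out by the skcipher walk is processed
 * under a single kernel_neon_begin()/kernel_neon_end() pair, and whenever
 * more data follows, the block count is rounded down to a multiple of the
 * walk stride (8 blocks) so the bit-sliced asm operates on full 8-block
 * groups wherever possible.
 */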

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = crypto_aes_expand_key(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
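		/*
		 * If the request is not a whole number of blocks, pass 'buf'
		 * to the asm routine so that it returns the keystream for the
		 * final partial block, which is XORed into the tail of the
		 * destination below via crypto_xor_cpy().
		 */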

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int ctr_encrypt_sync(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (!crypto_simd_usable())
		return aes_ctr_encrypt_fallback(&ctx->fallback, req);

	return ctr_encrypt(req);
}

static int __xts_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	/* encrypt the IV with the tweak key to produce the initial tweak */
	kernel_neon_begin();
	neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
	kernel_neon_end();

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_decrypt);
}
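
/*
 * The "__" prefixed algorithms below are internal (CRYPTO_ALG_INTERNAL):
 * they may only be invoked in a context where the NEON unit is usable.
 * aes_init() wraps each of them in a SIMD helper that registers the plain
 * algorithm name and defers to an asynchronous worker when SIMD is not
 * usable.  The synchronous "ctr(aes)" entry instead falls back to
 * aes_ctr_encrypt_fallback() in that case, so it can always make progress.
 */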

static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

module_init(aes_init);
module_exit(aes_exit);
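
/*
 * Illustrative sketch (not part of this driver): once this module is loaded,
 * the "ctr(aes)" implementation registered above can be obtained through the
 * regular skcipher API, e.g.:
 *
 *	struct crypto_skcipher *tfm = crypto_alloc_skcipher("ctr(aes)", 0, 0);
 *
 *	if (!IS_ERR(tfm)) {
 *		crypto_skcipher_setkey(tfm, key, AES_KEYSIZE_128);
 *		...
 *		crypto_free_skcipher(tfm);
 *	}
 *
 * The crypto core picks whichever registered "ctr(aes)" implementation has
 * the highest priority, so this driver is typically only selected when a
 * higher-priority implementation (such as the ARMv8 Crypto Extensions
 * driver) is not available.
 */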