// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

#include "aes-ctr-fallback.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}
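/*
 * CBC handling (descriptive note, not in the original source): CBC encryption
 * is inherently serial, as each block depends on the previous ciphertext
 * block, so cbc_encrypt() below hands the work to the plain (non-bitsliced)
 * NEON implementation via neon_aes_cbc_encrypt(). Decryption parallelises
 * across blocks and uses the bit-sliced code. This is also why the setkey
 * routine keeps a copy of the regular encryption round keys in ctx->enc.
 */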
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = crypto_aes_expand_key(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}
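/*
 * CTR handling (descriptive note, not in the original source): full blocks
 * are processed by the bit-sliced code in multiples of the walk stride.
 * When the request ends in a partial block, a stack buffer ('buf') is passed
 * as 'final'; the assembler routine is expected to deposit the keystream for
 * that trailing block there, and crypto_xor_cpy() below XORs it into the
 * remaining walk.total % AES_BLOCK_SIZE bytes.
 */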
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int ctr_encrypt_sync(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (!crypto_simd_usable())
		return aes_ctr_encrypt_fallback(&ctx->fallback, req);

	return ctr_encrypt(req);
}

static int __xts_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	kernel_neon_begin();
	neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
	kernel_neon_end();

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_decrypt);
}
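/*
 * Registration (descriptive note, not in the original source): the "__"
 * prefixed algorithms below are marked CRYPTO_ALG_INTERNAL, so they are only
 * reachable through the simd_skcipher wrappers created in aes_init(), which
 * defer work when kernel-mode NEON is not usable. The synchronous
 * "ctr-aes-neonbs" entry is registered directly at a slightly lower priority
 * and uses ctr_encrypt_sync() to fall back to a non-SIMD implementation when
 * crypto_simd_usable() returns false.
 */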
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

module_init(aes_init);
module_exit(aes_exit);