1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Bit sliced AES using NEON instructions 4 * 5 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org> 6 */ 7 8 #include <asm/neon.h> 9 #include <asm/simd.h> 10 #include <crypto/aes.h> 11 #include <crypto/ctr.h> 12 #include <crypto/internal/simd.h> 13 #include <crypto/internal/skcipher.h> 14 #include <crypto/scatterwalk.h> 15 #include <crypto/xts.h> 16 #include <linux/module.h> 17 18 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>"); 19 MODULE_LICENSE("GPL v2"); 20 21 MODULE_ALIAS_CRYPTO("ecb(aes)"); 22 MODULE_ALIAS_CRYPTO("cbc(aes)"); 23 MODULE_ALIAS_CRYPTO("ctr(aes)"); 24 MODULE_ALIAS_CRYPTO("xts(aes)"); 25 26 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds); 27 28 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[], 29 int rounds, int blocks); 30 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[], 31 int rounds, int blocks); 32 33 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[], 34 int rounds, int blocks, u8 iv[]); 35 36 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], 37 int rounds, int blocks, u8 iv[], u8 final[]); 38 39 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[], 40 int rounds, int blocks, u8 iv[]); 41 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[], 42 int rounds, int blocks, u8 iv[]); 43 44 /* borrowed from aes-neon-blk.ko */ 45 asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[], 46 int rounds, int blocks); 47 asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[], 48 int rounds, int blocks, u8 iv[]); 49 asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[], 50 u32 const rk1[], int rounds, int bytes, 51 u32 const rk2[], u8 iv[], int first); 52 asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[], 53 u32 const rk1[], int rounds, int bytes, 54 u32 const rk2[], u8 iv[], int first); 55 56 struct aesbs_ctx { 57 u8 rk[13 * (8 * AES_BLOCK_SIZE) + 32]; 58 int rounds; 59 } __aligned(AES_BLOCK_SIZE); 60 61 struct aesbs_cbc_ctx { 62 struct aesbs_ctx key; 63 u32 enc[AES_MAX_KEYLENGTH_U32]; 64 }; 65 66 struct aesbs_ctr_ctx { 67 struct aesbs_ctx key; /* must be first member */ 68 struct crypto_aes_ctx fallback; 69 }; 70 71 struct aesbs_xts_ctx { 72 struct aesbs_ctx key; 73 u32 twkey[AES_MAX_KEYLENGTH_U32]; 74 struct crypto_aes_ctx cts; 75 }; 76 77 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key, 78 unsigned int key_len) 79 { 80 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm); 81 struct crypto_aes_ctx rk; 82 int err; 83 84 err = aes_expandkey(&rk, in_key, key_len); 85 if (err) 86 return err; 87 88 ctx->rounds = 6 + key_len / 4; 89 90 kernel_neon_begin(); 91 aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds); 92 kernel_neon_end(); 93 94 return 0; 95 } 96 97 static int __ecb_crypt(struct skcipher_request *req, 98 void (*fn)(u8 out[], u8 const in[], u8 const rk[], 99 int rounds, int blocks)) 100 { 101 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 102 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm); 103 struct skcipher_walk walk; 104 int err; 105 106 err = skcipher_walk_virt(&walk, req, false); 107 108 while (walk.nbytes >= AES_BLOCK_SIZE) { 109 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 110 111 if (walk.nbytes < walk.total) 112 blocks = round_down(blocks, 113 walk.stride / AES_BLOCK_SIZE); 114 115 kernel_neon_begin(); 116 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk, 117 ctx->rounds, blocks); 118 kernel_neon_end(); 119 err = skcipher_walk_done(&walk, 120 walk.nbytes - blocks * AES_BLOCK_SIZE); 121 } 122 123 return err; 124 } 125 126 static int ecb_encrypt(struct skcipher_request *req) 127 { 128 return __ecb_crypt(req, aesbs_ecb_encrypt); 129 } 130 131 static int ecb_decrypt(struct skcipher_request *req) 132 { 133 return __ecb_crypt(req, aesbs_ecb_decrypt); 134 } 135 136 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key, 137 unsigned int key_len) 138 { 139 struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm); 140 struct crypto_aes_ctx rk; 141 int err; 142 143 err = aes_expandkey(&rk, in_key, key_len); 144 if (err) 145 return err; 146 147 ctx->key.rounds = 6 + key_len / 4; 148 149 memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc)); 150 151 kernel_neon_begin(); 152 aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds); 153 kernel_neon_end(); 154 memzero_explicit(&rk, sizeof(rk)); 155 156 return 0; 157 } 158 159 static int cbc_encrypt(struct skcipher_request *req) 160 { 161 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 162 struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm); 163 struct skcipher_walk walk; 164 int err; 165 166 err = skcipher_walk_virt(&walk, req, false); 167 168 while (walk.nbytes >= AES_BLOCK_SIZE) { 169 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 170 171 /* fall back to the non-bitsliced NEON implementation */ 172 kernel_neon_begin(); 173 neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr, 174 ctx->enc, ctx->key.rounds, blocks, 175 walk.iv); 176 kernel_neon_end(); 177 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE); 178 } 179 return err; 180 } 181 182 static int cbc_decrypt(struct skcipher_request *req) 183 { 184 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 185 struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm); 186 struct skcipher_walk walk; 187 int err; 188 189 err = skcipher_walk_virt(&walk, req, false); 190 191 while (walk.nbytes >= AES_BLOCK_SIZE) { 192 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 193 194 if (walk.nbytes < walk.total) 195 blocks = round_down(blocks, 196 walk.stride / AES_BLOCK_SIZE); 197 198 kernel_neon_begin(); 199 aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr, 200 ctx->key.rk, ctx->key.rounds, blocks, 201 walk.iv); 202 kernel_neon_end(); 203 err = skcipher_walk_done(&walk, 204 walk.nbytes - blocks * AES_BLOCK_SIZE); 205 } 206 207 return err; 208 } 209 210 static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key, 211 unsigned int key_len) 212 { 213 struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm); 214 int err; 215 216 err = aes_expandkey(&ctx->fallback, in_key, key_len); 217 if (err) 218 return err; 219 220 ctx->key.rounds = 6 + key_len / 4; 221 222 kernel_neon_begin(); 223 aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds); 224 kernel_neon_end(); 225 226 return 0; 227 } 228 229 static int ctr_encrypt(struct skcipher_request *req) 230 { 231 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 232 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm); 233 struct skcipher_walk walk; 234 u8 buf[AES_BLOCK_SIZE]; 235 int err; 236 237 err = skcipher_walk_virt(&walk, req, false); 238 239 while (walk.nbytes > 0) { 240 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 241 u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL; 242 243 if (walk.nbytes < walk.total) { 244 blocks = round_down(blocks, 245 walk.stride / AES_BLOCK_SIZE); 246 final = NULL; 247 } 248 249 kernel_neon_begin(); 250 aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr, 251 ctx->rk, ctx->rounds, blocks, walk.iv, final); 252 kernel_neon_end(); 253 254 if (final) { 255 u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE; 256 u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE; 257 258 crypto_xor_cpy(dst, src, final, 259 walk.total % AES_BLOCK_SIZE); 260 261 err = skcipher_walk_done(&walk, 0); 262 break; 263 } 264 err = skcipher_walk_done(&walk, 265 walk.nbytes - blocks * AES_BLOCK_SIZE); 266 } 267 return err; 268 } 269 270 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key, 271 unsigned int key_len) 272 { 273 struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm); 274 struct crypto_aes_ctx rk; 275 int err; 276 277 err = xts_verify_key(tfm, in_key, key_len); 278 if (err) 279 return err; 280 281 key_len /= 2; 282 err = aes_expandkey(&ctx->cts, in_key, key_len); 283 if (err) 284 return err; 285 286 err = aes_expandkey(&rk, in_key + key_len, key_len); 287 if (err) 288 return err; 289 290 memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey)); 291 292 return aesbs_setkey(tfm, in_key, key_len); 293 } 294 295 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst) 296 { 297 struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm); 298 unsigned long flags; 299 300 /* 301 * Temporarily disable interrupts to avoid races where 302 * cachelines are evicted when the CPU is interrupted 303 * to do something else. 304 */ 305 local_irq_save(flags); 306 aes_encrypt(&ctx->fallback, dst, src); 307 local_irq_restore(flags); 308 } 309 310 static int ctr_encrypt_sync(struct skcipher_request *req) 311 { 312 if (!crypto_simd_usable()) 313 return crypto_ctr_encrypt_walk(req, ctr_encrypt_one); 314 315 return ctr_encrypt(req); 316 } 317 318 static int __xts_crypt(struct skcipher_request *req, bool encrypt, 319 void (*fn)(u8 out[], u8 const in[], u8 const rk[], 320 int rounds, int blocks, u8 iv[])) 321 { 322 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 323 struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm); 324 int tail = req->cryptlen % (8 * AES_BLOCK_SIZE); 325 struct scatterlist sg_src[2], sg_dst[2]; 326 struct skcipher_request subreq; 327 struct scatterlist *src, *dst; 328 struct skcipher_walk walk; 329 int nbytes, err; 330 int first = 1; 331 u8 *out, *in; 332 333 if (req->cryptlen < AES_BLOCK_SIZE) 334 return -EINVAL; 335 336 /* ensure that the cts tail is covered by a single step */ 337 if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) { 338 int xts_blocks = DIV_ROUND_UP(req->cryptlen, 339 AES_BLOCK_SIZE) - 2; 340 341 skcipher_request_set_tfm(&subreq, tfm); 342 skcipher_request_set_callback(&subreq, 343 skcipher_request_flags(req), 344 NULL, NULL); 345 skcipher_request_set_crypt(&subreq, req->src, req->dst, 346 xts_blocks * AES_BLOCK_SIZE, 347 req->iv); 348 req = &subreq; 349 } else { 350 tail = 0; 351 } 352 353 err = skcipher_walk_virt(&walk, req, false); 354 if (err) 355 return err; 356 357 while (walk.nbytes >= AES_BLOCK_SIZE) { 358 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 359 360 if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE) 361 blocks = round_down(blocks, 362 walk.stride / AES_BLOCK_SIZE); 363 364 out = walk.dst.virt.addr; 365 in = walk.src.virt.addr; 366 nbytes = walk.nbytes; 367 368 kernel_neon_begin(); 369 if (likely(blocks > 6)) { /* plain NEON is faster otherwise */ 370 if (first) 371 neon_aes_ecb_encrypt(walk.iv, walk.iv, 372 ctx->twkey, 373 ctx->key.rounds, 1); 374 first = 0; 375 376 fn(out, in, ctx->key.rk, ctx->key.rounds, blocks, 377 walk.iv); 378 379 out += blocks * AES_BLOCK_SIZE; 380 in += blocks * AES_BLOCK_SIZE; 381 nbytes -= blocks * AES_BLOCK_SIZE; 382 } 383 384 if (walk.nbytes == walk.total && nbytes > 0) 385 goto xts_tail; 386 387 kernel_neon_end(); 388 err = skcipher_walk_done(&walk, nbytes); 389 } 390 391 if (err || likely(!tail)) 392 return err; 393 394 /* handle ciphertext stealing */ 395 dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen); 396 if (req->dst != req->src) 397 dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen); 398 399 skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail, 400 req->iv); 401 402 err = skcipher_walk_virt(&walk, req, false); 403 if (err) 404 return err; 405 406 out = walk.dst.virt.addr; 407 in = walk.src.virt.addr; 408 nbytes = walk.nbytes; 409 410 kernel_neon_begin(); 411 xts_tail: 412 if (encrypt) 413 neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds, 414 nbytes, ctx->twkey, walk.iv, first ?: 2); 415 else 416 neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds, 417 nbytes, ctx->twkey, walk.iv, first ?: 2); 418 kernel_neon_end(); 419 420 return skcipher_walk_done(&walk, 0); 421 } 422 423 static int xts_encrypt(struct skcipher_request *req) 424 { 425 return __xts_crypt(req, true, aesbs_xts_encrypt); 426 } 427 428 static int xts_decrypt(struct skcipher_request *req) 429 { 430 return __xts_crypt(req, false, aesbs_xts_decrypt); 431 } 432 433 static struct skcipher_alg aes_algs[] = { { 434 .base.cra_name = "__ecb(aes)", 435 .base.cra_driver_name = "__ecb-aes-neonbs", 436 .base.cra_priority = 250, 437 .base.cra_blocksize = AES_BLOCK_SIZE, 438 .base.cra_ctxsize = sizeof(struct aesbs_ctx), 439 .base.cra_module = THIS_MODULE, 440 .base.cra_flags = CRYPTO_ALG_INTERNAL, 441 442 .min_keysize = AES_MIN_KEY_SIZE, 443 .max_keysize = AES_MAX_KEY_SIZE, 444 .walksize = 8 * AES_BLOCK_SIZE, 445 .setkey = aesbs_setkey, 446 .encrypt = ecb_encrypt, 447 .decrypt = ecb_decrypt, 448 }, { 449 .base.cra_name = "__cbc(aes)", 450 .base.cra_driver_name = "__cbc-aes-neonbs", 451 .base.cra_priority = 250, 452 .base.cra_blocksize = AES_BLOCK_SIZE, 453 .base.cra_ctxsize = sizeof(struct aesbs_cbc_ctx), 454 .base.cra_module = THIS_MODULE, 455 .base.cra_flags = CRYPTO_ALG_INTERNAL, 456 457 .min_keysize = AES_MIN_KEY_SIZE, 458 .max_keysize = AES_MAX_KEY_SIZE, 459 .walksize = 8 * AES_BLOCK_SIZE, 460 .ivsize = AES_BLOCK_SIZE, 461 .setkey = aesbs_cbc_setkey, 462 .encrypt = cbc_encrypt, 463 .decrypt = cbc_decrypt, 464 }, { 465 .base.cra_name = "__ctr(aes)", 466 .base.cra_driver_name = "__ctr-aes-neonbs", 467 .base.cra_priority = 250, 468 .base.cra_blocksize = 1, 469 .base.cra_ctxsize = sizeof(struct aesbs_ctx), 470 .base.cra_module = THIS_MODULE, 471 .base.cra_flags = CRYPTO_ALG_INTERNAL, 472 473 .min_keysize = AES_MIN_KEY_SIZE, 474 .max_keysize = AES_MAX_KEY_SIZE, 475 .chunksize = AES_BLOCK_SIZE, 476 .walksize = 8 * AES_BLOCK_SIZE, 477 .ivsize = AES_BLOCK_SIZE, 478 .setkey = aesbs_setkey, 479 .encrypt = ctr_encrypt, 480 .decrypt = ctr_encrypt, 481 }, { 482 .base.cra_name = "ctr(aes)", 483 .base.cra_driver_name = "ctr-aes-neonbs", 484 .base.cra_priority = 250 - 1, 485 .base.cra_blocksize = 1, 486 .base.cra_ctxsize = sizeof(struct aesbs_ctr_ctx), 487 .base.cra_module = THIS_MODULE, 488 489 .min_keysize = AES_MIN_KEY_SIZE, 490 .max_keysize = AES_MAX_KEY_SIZE, 491 .chunksize = AES_BLOCK_SIZE, 492 .walksize = 8 * AES_BLOCK_SIZE, 493 .ivsize = AES_BLOCK_SIZE, 494 .setkey = aesbs_ctr_setkey_sync, 495 .encrypt = ctr_encrypt_sync, 496 .decrypt = ctr_encrypt_sync, 497 }, { 498 .base.cra_name = "__xts(aes)", 499 .base.cra_driver_name = "__xts-aes-neonbs", 500 .base.cra_priority = 250, 501 .base.cra_blocksize = AES_BLOCK_SIZE, 502 .base.cra_ctxsize = sizeof(struct aesbs_xts_ctx), 503 .base.cra_module = THIS_MODULE, 504 .base.cra_flags = CRYPTO_ALG_INTERNAL, 505 506 .min_keysize = 2 * AES_MIN_KEY_SIZE, 507 .max_keysize = 2 * AES_MAX_KEY_SIZE, 508 .walksize = 8 * AES_BLOCK_SIZE, 509 .ivsize = AES_BLOCK_SIZE, 510 .setkey = aesbs_xts_setkey, 511 .encrypt = xts_encrypt, 512 .decrypt = xts_decrypt, 513 } }; 514 515 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)]; 516 517 static void aes_exit(void) 518 { 519 int i; 520 521 for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++) 522 if (aes_simd_algs[i]) 523 simd_skcipher_free(aes_simd_algs[i]); 524 525 crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs)); 526 } 527 528 static int __init aes_init(void) 529 { 530 struct simd_skcipher_alg *simd; 531 const char *basename; 532 const char *algname; 533 const char *drvname; 534 int err; 535 int i; 536 537 if (!cpu_have_named_feature(ASIMD)) 538 return -ENODEV; 539 540 err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs)); 541 if (err) 542 return err; 543 544 for (i = 0; i < ARRAY_SIZE(aes_algs); i++) { 545 if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL)) 546 continue; 547 548 algname = aes_algs[i].base.cra_name + 2; 549 drvname = aes_algs[i].base.cra_driver_name + 2; 550 basename = aes_algs[i].base.cra_driver_name; 551 simd = simd_skcipher_create_compat(algname, drvname, basename); 552 err = PTR_ERR(simd); 553 if (IS_ERR(simd)) 554 goto unregister_simds; 555 556 aes_simd_algs[i] = simd; 557 } 558 return 0; 559 560 unregister_simds: 561 aes_exit(); 562 return err; 563 } 564 565 module_init(aes_init); 566 module_exit(aes_exit); 567