// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
			if (first)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 0;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}

		if (walk.nbytes == walk.total && nbytes > 0)
			goto xts_tail;

		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
xts_tail:
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);