// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctr_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

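/*
 * Shared ECB helper: each skcipher_walk step is handed to the bit-sliced
 * NEON code. For all but the final step the block count is rounded down
 * to a multiple of the 8-block walk stride, so any remainder is carried
 * over into the last iteration.
 */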
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
				unsigned int key_len)
{
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_neon_begin();
		if (blocks >= 8) {
			aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
					  blocks, walk.iv);
			dst += blocks * AES_BLOCK_SIZE;
			src += blocks * AES_BLOCK_SIZE;
		}
		if (nbytes && walk.nbytes == walk.total) {
			neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
					     nbytes, walk.iv);
			nbytes = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

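/*
 * XTS takes a double-length key: the first half is expanded both into a
 * regular key schedule (used by the plain NEON code for the ciphertext
 * stealing tail) and into the bit-sliced format, while the second half
 * becomes the tweak key.
 */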
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

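/*
 * XTS helper: full 8-block groups are processed by the bit-sliced code,
 * remaining whole blocks and the ciphertext stealing tail by the plain
 * NEON implementation. If the partial final block would otherwise end up
 * in a walk step of its own, the request is shortened so that it is
 * processed together with the preceding full block in a separate pass.
 */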
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (blocks >= 8) {
			if (first == 1)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 2;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}
		if (walk.nbytes == walk.total && nbytes > 0) {
			if (encrypt)
				neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			else
				neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			nbytes = first = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);