/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 NEON
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

#define BYTES2BLKS(nbytes)      ((nbytes) >> 4)
#define BYTES2BLK8(nbytes)      (((nbytes) >> 4) & ~(8 - 1))

asmlinkage void sm4_neon_crypt_blk1_8(const u32 *rkey, u8 *dst, const u8 *src,
                                      unsigned int nblks);
asmlinkage void sm4_neon_crypt_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                    unsigned int nblks);
asmlinkage void sm4_neon_cbc_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_cfb_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_ctr_enc_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                      u8 *iv, unsigned int nblks);

static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
                      unsigned int key_len)
{
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_expandkey(ctx, key, key_len);
}

static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_crypt_blk8(rkey, dst, src, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        sm4_neon_crypt_blk1_8(rkey, dst, src, nblks);
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

static int sm4_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *iv = walk.iv;
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                while (nbytes >= SM4_BLOCK_SIZE) {
                        crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
                        sm4_crypt_block(ctx->rkey_enc, dst, dst);
                        iv = dst;
                        src += SM4_BLOCK_SIZE;
                        dst += SM4_BLOCK_SIZE;
                        nbytes -= SM4_BLOCK_SIZE;
                }
                if (iv != walk.iv)
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_cbc_dec_blk8(ctx->rkey_dec, dst, src,
                                              walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        u8 iv[SM4_BLOCK_SIZE];
                        int i;

                        sm4_neon_crypt_blk1_8(ctx->rkey_dec, keystream,
                                              src, nblks);

                        src += ((int)nblks - 2) * SM4_BLOCK_SIZE;
                        dst += (nblks - 1) * SM4_BLOCK_SIZE;
                        memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

                        for (i = nblks - 1; i > 0; i--) {
                                crypto_xor_cpy(dst, src,
                                               &keystream[i * SM4_BLOCK_SIZE],
                                               SM4_BLOCK_SIZE);
                                src -= SM4_BLOCK_SIZE;
                                dst -= SM4_BLOCK_SIZE;
                        }
                        crypto_xor_cpy(dst, walk.iv,
                                       keystream, SM4_BLOCK_SIZE);
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_cfb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                u8 keystream[SM4_BLOCK_SIZE];
                const u8 *iv = walk.iv;
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                while (nbytes >= SM4_BLOCK_SIZE) {
                        sm4_crypt_block(ctx->rkey_enc, keystream, iv);
                        crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
                        iv = dst;
                        src += SM4_BLOCK_SIZE;
                        dst += SM4_BLOCK_SIZE;
                        nbytes -= SM4_BLOCK_SIZE;
                }
                if (iv != walk.iv)
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_cfb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_cfb_dec_blk8(ctx->rkey_enc, dst, src,
                                              walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];

                        memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
                        if (nblks > 1)
                                memcpy(&keystream[SM4_BLOCK_SIZE], src,
                                       (nblks - 1) * SM4_BLOCK_SIZE);
                        memcpy(walk.iv, src + (nblks - 1) * SM4_BLOCK_SIZE,
                               SM4_BLOCK_SIZE);

                        sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
                                              keystream, nblks);

                        crypto_xor_cpy(dst, src, keystream,
                                       nblks * SM4_BLOCK_SIZE);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_ctr_crypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_ctr_enc_blk8(ctx->rkey_enc, dst, src,
                                              walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        int i;

                        for (i = 0; i < nblks; i++) {
                                memcpy(&keystream[i * SM4_BLOCK_SIZE],
                                       walk.iv, SM4_BLOCK_SIZE);
                                crypto_inc(walk.iv, SM4_BLOCK_SIZE);
                        }
                        sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
                                              keystream, nblks);

                        crypto_xor_cpy(dst, src, keystream,
                                       nblks * SM4_BLOCK_SIZE);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_inc(walk.iv, SM4_BLOCK_SIZE);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static struct skcipher_alg sm4_algs[] = {
        {
                .base = {
                        .cra_name               = "ecb(sm4)",
                        .cra_driver_name        = "ecb-sm4-neon",
                        .cra_priority           = 200,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_ecb_encrypt,
                .decrypt        = sm4_ecb_decrypt,
        }, {
                .base = {
                        .cra_name               = "cbc(sm4)",
                        .cra_driver_name        = "cbc-sm4-neon",
                        .cra_priority           = 200,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_cbc_encrypt,
                .decrypt        = sm4_cbc_decrypt,
        }, {
                .base = {
                        .cra_name               = "cfb(sm4)",
                        .cra_driver_name        = "cfb-sm4-neon",
                        .cra_priority           = 200,
                        .cra_blocksize          = 1,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .chunksize      = SM4_BLOCK_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_cfb_encrypt,
                .decrypt        = sm4_cfb_decrypt,
        }, {
                .base = {
                        .cra_name               = "ctr(sm4)",
                        .cra_driver_name        = "ctr-sm4-neon",
                        .cra_priority           = 200,
                        .cra_blocksize          = 1,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .chunksize      = SM4_BLOCK_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_ctr_crypt,
                .decrypt        = sm4_ctr_crypt,
        }
};

static int __init sm4_init(void)
{
        return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
        crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("sm4-neon");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");