1 /* SPDX-License-Identifier: GPL-2.0-or-later */ 2 /* 3 * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions 4 * as specified in 5 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html 6 * 7 * Copyright (C) 2022, Alibaba Group. 8 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com> 9 */ 10 11 #include <linux/module.h> 12 #include <linux/crypto.h> 13 #include <linux/kernel.h> 14 #include <linux/cpufeature.h> 15 #include <asm/neon.h> 16 #include <asm/simd.h> 17 #include <crypto/internal/simd.h> 18 #include <crypto/internal/skcipher.h> 19 #include <crypto/sm4.h> 20 21 #define BYTES2BLKS(nbytes) ((nbytes) >> 4) 22 23 asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec, 24 const u32 *fk, const u32 *ck); 25 asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src); 26 asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src, 27 unsigned int nblks); 28 asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src, 29 u8 *iv, unsigned int nblocks); 30 asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src, 31 u8 *iv, unsigned int nblocks); 32 asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src, 33 u8 *iv, unsigned int nblks); 34 asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src, 35 u8 *iv, unsigned int nblks); 36 asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src, 37 u8 *iv, unsigned int nblks); 38 39 static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key, 40 unsigned int key_len) 41 { 42 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 43 44 if (key_len != SM4_KEY_SIZE) 45 return -EINVAL; 46 47 sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec, 48 crypto_sm4_fk, crypto_sm4_ck); 49 return 0; 50 } 51 52 static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey) 53 { 54 struct skcipher_walk walk; 55 unsigned int nbytes; 56 int err; 57 58 err = skcipher_walk_virt(&walk, req, false); 59 60 while ((nbytes = walk.nbytes) > 0) { 61 const u8 *src = walk.src.virt.addr; 62 u8 *dst = walk.dst.virt.addr; 63 unsigned int nblks; 64 65 kernel_neon_begin(); 66 67 nblks = BYTES2BLKS(nbytes); 68 if (nblks) { 69 sm4_ce_crypt(rkey, dst, src, nblks); 70 nbytes -= nblks * SM4_BLOCK_SIZE; 71 } 72 73 kernel_neon_end(); 74 75 err = skcipher_walk_done(&walk, nbytes); 76 } 77 78 return err; 79 } 80 81 static int sm4_ecb_encrypt(struct skcipher_request *req) 82 { 83 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 84 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 85 86 return sm4_ecb_do_crypt(req, ctx->rkey_enc); 87 } 88 89 static int sm4_ecb_decrypt(struct skcipher_request *req) 90 { 91 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 92 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 93 94 return sm4_ecb_do_crypt(req, ctx->rkey_dec); 95 } 96 97 static int sm4_cbc_crypt(struct skcipher_request *req, 98 struct sm4_ctx *ctx, bool encrypt) 99 { 100 struct skcipher_walk walk; 101 unsigned int nbytes; 102 int err; 103 104 err = skcipher_walk_virt(&walk, req, false); 105 if (err) 106 return err; 107 108 while ((nbytes = walk.nbytes) > 0) { 109 const u8 *src = walk.src.virt.addr; 110 u8 *dst = walk.dst.virt.addr; 111 unsigned int nblocks; 112 113 nblocks = nbytes / SM4_BLOCK_SIZE; 114 if (nblocks) { 115 kernel_neon_begin(); 116 117 if (encrypt) 118 sm4_ce_cbc_enc(ctx->rkey_enc, dst, src, 119 walk.iv, nblocks); 120 else 121 sm4_ce_cbc_dec(ctx->rkey_dec, dst, src, 122 walk.iv, nblocks); 123 124 kernel_neon_end(); 125 } 126 127 err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE); 128 } 129 130 return err; 131 } 132 133 static int sm4_cbc_encrypt(struct skcipher_request *req) 134 { 135 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 136 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 137 138 return sm4_cbc_crypt(req, ctx, true); 139 } 140 141 static int sm4_cbc_decrypt(struct skcipher_request *req) 142 { 143 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 144 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 145 146 return sm4_cbc_crypt(req, ctx, false); 147 } 148 149 static int sm4_cfb_encrypt(struct skcipher_request *req) 150 { 151 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 152 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 153 struct skcipher_walk walk; 154 unsigned int nbytes; 155 int err; 156 157 err = skcipher_walk_virt(&walk, req, false); 158 159 while ((nbytes = walk.nbytes) > 0) { 160 const u8 *src = walk.src.virt.addr; 161 u8 *dst = walk.dst.virt.addr; 162 unsigned int nblks; 163 164 kernel_neon_begin(); 165 166 nblks = BYTES2BLKS(nbytes); 167 if (nblks) { 168 sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks); 169 dst += nblks * SM4_BLOCK_SIZE; 170 src += nblks * SM4_BLOCK_SIZE; 171 nbytes -= nblks * SM4_BLOCK_SIZE; 172 } 173 174 /* tail */ 175 if (walk.nbytes == walk.total && nbytes > 0) { 176 u8 keystream[SM4_BLOCK_SIZE]; 177 178 sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv); 179 crypto_xor_cpy(dst, src, keystream, nbytes); 180 nbytes = 0; 181 } 182 183 kernel_neon_end(); 184 185 err = skcipher_walk_done(&walk, nbytes); 186 } 187 188 return err; 189 } 190 191 static int sm4_cfb_decrypt(struct skcipher_request *req) 192 { 193 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 194 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 195 struct skcipher_walk walk; 196 unsigned int nbytes; 197 int err; 198 199 err = skcipher_walk_virt(&walk, req, false); 200 201 while ((nbytes = walk.nbytes) > 0) { 202 const u8 *src = walk.src.virt.addr; 203 u8 *dst = walk.dst.virt.addr; 204 unsigned int nblks; 205 206 kernel_neon_begin(); 207 208 nblks = BYTES2BLKS(nbytes); 209 if (nblks) { 210 sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks); 211 dst += nblks * SM4_BLOCK_SIZE; 212 src += nblks * SM4_BLOCK_SIZE; 213 nbytes -= nblks * SM4_BLOCK_SIZE; 214 } 215 216 /* tail */ 217 if (walk.nbytes == walk.total && nbytes > 0) { 218 u8 keystream[SM4_BLOCK_SIZE]; 219 220 sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv); 221 crypto_xor_cpy(dst, src, keystream, nbytes); 222 nbytes = 0; 223 } 224 225 kernel_neon_end(); 226 227 err = skcipher_walk_done(&walk, nbytes); 228 } 229 230 return err; 231 } 232 233 static int sm4_ctr_crypt(struct skcipher_request *req) 234 { 235 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 236 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 237 struct skcipher_walk walk; 238 unsigned int nbytes; 239 int err; 240 241 err = skcipher_walk_virt(&walk, req, false); 242 243 while ((nbytes = walk.nbytes) > 0) { 244 const u8 *src = walk.src.virt.addr; 245 u8 *dst = walk.dst.virt.addr; 246 unsigned int nblks; 247 248 kernel_neon_begin(); 249 250 nblks = BYTES2BLKS(nbytes); 251 if (nblks) { 252 sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks); 253 dst += nblks * SM4_BLOCK_SIZE; 254 src += nblks * SM4_BLOCK_SIZE; 255 nbytes -= nblks * SM4_BLOCK_SIZE; 256 } 257 258 /* tail */ 259 if (walk.nbytes == walk.total && nbytes > 0) { 260 u8 keystream[SM4_BLOCK_SIZE]; 261 262 sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv); 263 crypto_inc(walk.iv, SM4_BLOCK_SIZE); 264 crypto_xor_cpy(dst, src, keystream, nbytes); 265 nbytes = 0; 266 } 267 268 kernel_neon_end(); 269 270 err = skcipher_walk_done(&walk, nbytes); 271 } 272 273 return err; 274 } 275 276 static struct skcipher_alg sm4_algs[] = { 277 { 278 .base = { 279 .cra_name = "ecb(sm4)", 280 .cra_driver_name = "ecb-sm4-ce", 281 .cra_priority = 400, 282 .cra_blocksize = SM4_BLOCK_SIZE, 283 .cra_ctxsize = sizeof(struct sm4_ctx), 284 .cra_module = THIS_MODULE, 285 }, 286 .min_keysize = SM4_KEY_SIZE, 287 .max_keysize = SM4_KEY_SIZE, 288 .setkey = sm4_setkey, 289 .encrypt = sm4_ecb_encrypt, 290 .decrypt = sm4_ecb_decrypt, 291 }, { 292 .base = { 293 .cra_name = "cbc(sm4)", 294 .cra_driver_name = "cbc-sm4-ce", 295 .cra_priority = 400, 296 .cra_blocksize = SM4_BLOCK_SIZE, 297 .cra_ctxsize = sizeof(struct sm4_ctx), 298 .cra_module = THIS_MODULE, 299 }, 300 .min_keysize = SM4_KEY_SIZE, 301 .max_keysize = SM4_KEY_SIZE, 302 .ivsize = SM4_BLOCK_SIZE, 303 .setkey = sm4_setkey, 304 .encrypt = sm4_cbc_encrypt, 305 .decrypt = sm4_cbc_decrypt, 306 }, { 307 .base = { 308 .cra_name = "cfb(sm4)", 309 .cra_driver_name = "cfb-sm4-ce", 310 .cra_priority = 400, 311 .cra_blocksize = 1, 312 .cra_ctxsize = sizeof(struct sm4_ctx), 313 .cra_module = THIS_MODULE, 314 }, 315 .min_keysize = SM4_KEY_SIZE, 316 .max_keysize = SM4_KEY_SIZE, 317 .ivsize = SM4_BLOCK_SIZE, 318 .chunksize = SM4_BLOCK_SIZE, 319 .setkey = sm4_setkey, 320 .encrypt = sm4_cfb_encrypt, 321 .decrypt = sm4_cfb_decrypt, 322 }, { 323 .base = { 324 .cra_name = "ctr(sm4)", 325 .cra_driver_name = "ctr-sm4-ce", 326 .cra_priority = 400, 327 .cra_blocksize = 1, 328 .cra_ctxsize = sizeof(struct sm4_ctx), 329 .cra_module = THIS_MODULE, 330 }, 331 .min_keysize = SM4_KEY_SIZE, 332 .max_keysize = SM4_KEY_SIZE, 333 .ivsize = SM4_BLOCK_SIZE, 334 .chunksize = SM4_BLOCK_SIZE, 335 .setkey = sm4_setkey, 336 .encrypt = sm4_ctr_crypt, 337 .decrypt = sm4_ctr_crypt, 338 } 339 }; 340 341 static int __init sm4_init(void) 342 { 343 return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs)); 344 } 345 346 static void __exit sm4_exit(void) 347 { 348 crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs)); 349 } 350 351 module_cpu_feature_match(SM4, sm4_init); 352 module_exit(sm4_exit); 353 354 MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 Crypto Extensions"); 355 MODULE_ALIAS_CRYPTO("sm4-ce"); 356 MODULE_ALIAS_CRYPTO("sm4"); 357 MODULE_ALIAS_CRYPTO("ecb(sm4)"); 358 MODULE_ALIAS_CRYPTO("cbc(sm4)"); 359 MODULE_ALIAS_CRYPTO("cfb(sm4)"); 360 MODULE_ALIAS_CRYPTO("ctr(sm4)"); 361 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>"); 362 MODULE_LICENSE("GPL v2"); 363