/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 NEON
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

asmlinkage void sm4_neon_crypt(const u32 *rkey, u8 *dst, const u8 *src,
			       unsigned int nblocks);
asmlinkage void sm4_neon_cbc_dec(const u32 *rkey_dec, u8 *dst, const u8 *src,
				 u8 *iv, unsigned int nblocks);
asmlinkage void sm4_neon_cfb_dec(const u32 *rkey_enc, u8 *dst, const u8 *src,
				 u8 *iv, unsigned int nblocks);
asmlinkage void sm4_neon_ctr_crypt(const u32 *rkey_enc, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nblocks);

static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			sm4_neon_crypt(rkey, dst, src, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

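		/*
		 * CBC encryption is inherently serial: each ciphertext
		 * block chains into the next input block, so encrypt
		 * block by block with the generic SM4 routine rather
		 * than the NEON path.
		 */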
		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			sm4_neon_cbc_dec(ctx->rkey_dec, dst, src,
					 walk.iv, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

static int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

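		/*
		 * Like CBC, CFB encryption feeds each ciphertext block
		 * back in as the next block's input, so it is done
		 * serially with the generic per-block SM4 routine.
		 */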
		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_cfb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			sm4_neon_cfb_dec(ctx->rkey_enc, dst, src,
					 walk.iv, nblocks);

			kernel_neon_end();

			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			sm4_neon_ctr_crypt(ctx->rkey_enc, dst, src,
					   walk.iv, nblocks);

			kernel_neon_end();

			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name = "ecb(sm4)",
			.cra_driver_name = "ecb-sm4-neon",
			.cra_priority = 200,
			.cra_blocksize = SM4_BLOCK_SIZE,
			.cra_ctxsize = sizeof(struct sm4_ctx),
			.cra_module = THIS_MODULE,
		},
		.min_keysize = SM4_KEY_SIZE,
		.max_keysize = SM4_KEY_SIZE,
		.setkey = sm4_setkey,
		.encrypt = sm4_ecb_encrypt,
		.decrypt = sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name = "cbc(sm4)",
			.cra_driver_name = "cbc-sm4-neon",
			.cra_priority = 200,
			.cra_blocksize = SM4_BLOCK_SIZE,
			.cra_ctxsize = sizeof(struct sm4_ctx),
			.cra_module = THIS_MODULE,
		},
		.min_keysize = SM4_KEY_SIZE,
		.max_keysize = SM4_KEY_SIZE,
		.ivsize = SM4_BLOCK_SIZE,
		.setkey = sm4_setkey,
		.encrypt = sm4_cbc_encrypt,
		.decrypt = sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name = "cfb(sm4)",
			.cra_driver_name = "cfb-sm4-neon",
			.cra_priority = 200,
			.cra_blocksize = 1,
			.cra_ctxsize = sizeof(struct sm4_ctx),
			.cra_module = THIS_MODULE,
		},
		.min_keysize = SM4_KEY_SIZE,
		.max_keysize = SM4_KEY_SIZE,
		.ivsize = SM4_BLOCK_SIZE,
		.chunksize = SM4_BLOCK_SIZE,
		.setkey = sm4_setkey,
		.encrypt = sm4_cfb_encrypt,
		.decrypt = sm4_cfb_decrypt,
	}, {
		.base = {
			.cra_name = "ctr(sm4)",
			.cra_driver_name = "ctr-sm4-neon",
			.cra_priority = 200,
			.cra_blocksize = 1,
			.cra_ctxsize = sizeof(struct sm4_ctx),
			.cra_module = THIS_MODULE,
		},
		.min_keysize = SM4_KEY_SIZE,
		.max_keysize = SM4_KEY_SIZE,
		.ivsize = SM4_BLOCK_SIZE,
		.chunksize = SM4_BLOCK_SIZE,
		.setkey = sm4_setkey,
		.encrypt = sm4_ctr_crypt,
		.decrypt = sm4_ctr_crypt,
	}
};

static int __init sm4_init(void)
{
	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("sm4-neon");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");