/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 NEON
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

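/*
 * Byte-count helpers: BYTES2BLKS() converts a byte count into a number of
 * whole 16-byte SM4 blocks, and BYTES2BLK8() rounds that block count down
 * to a multiple of eight, the width of the 8-way NEON routines.
 */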
#define BYTES2BLKS(nbytes)	((nbytes) >> 4)
#define BYTES2BLK8(nbytes)	(((nbytes) >> 4) & ~(8 - 1))

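/*
 * NEON routines provided by the accompanying assembly: the *_blk8 variants
 * expect a multiple of eight blocks and work eight blocks at a time,
 * blk1_8 handles any count from one to eight, and the CBC/CFB/CTR helpers
 * also read and update the IV (or counter) in place.
 */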
asmlinkage void sm4_neon_crypt_blk1_8(const u32 *rkey, u8 *dst, const u8 *src,
				      unsigned int nblks);
asmlinkage void sm4_neon_crypt_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				    unsigned int nblks);
asmlinkage void sm4_neon_cbc_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_cfb_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_ctr_enc_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);

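/* Expand the user-supplied key into encryption and decryption round keys. */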
static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

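/*
 * ECB walk: encrypt/decrypt as many 8-block groups as possible with the
 * wide NEON routine, hand the remaining 1..8 whole blocks to blk1_8, and
 * return any leftover bytes to the walker via skcipher_walk_done().
 */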
static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_crypt_blk8(rkey, dst, src, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_neon_crypt_blk1_8(rkey, dst, src, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

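/*
 * CBC encryption is inherently serial (each block is chained to the
 * previous ciphertext), so it uses the scalar sm4_crypt_block() helper;
 * only decryption below is parallelised with NEON.
 */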
static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

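/*
 * CBC decryption: multiples of eight blocks go through the NEON bulk
 * routine.  The 1..8 block tail is first block-decrypted into a local
 * buffer, the last ciphertext block is saved as the next IV, and the XOR
 * with the previous ciphertext is then applied from the last block
 * backwards so that in-place (src == dst) requests stay correct.
 */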
static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_cbc_dec_blk8(ctx->rkey_dec, dst, src,
					walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			int i;

			sm4_neon_crypt_blk1_8(ctx->rkey_dec, keystream,
					src, nblks);

			src += ((int)nblks - 2) * SM4_BLOCK_SIZE;
			dst += (nblks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv,
					keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

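/*
 * CFB encryption is serial, like CBC encryption: each keystream block is
 * the encryption of the previous ciphertext block, so the scalar helper is
 * used.  A final partial block is handled once the walk reaches the end of
 * the request.
 */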
static int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

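/*
 * CFB decryption can run in parallel because all keystream inputs (the IV
 * followed by the received ciphertext blocks) are known up front.  The
 * 1..8 block tail builds that input sequence in a local buffer, encrypts
 * it with blk1_8, and XORs the result onto the ciphertext.
 */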
static int sm4_cfb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_cfb_dec_blk8(ctx->rkey_enc, dst, src,
					walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			if (nblks > 1)
				memcpy(&keystream[SM4_BLOCK_SIZE], src,
					(nblks - 1) * SM4_BLOCK_SIZE);
			memcpy(walk.iv, src + (nblks - 1) * SM4_BLOCK_SIZE,
				SM4_BLOCK_SIZE);

			sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
					keystream, nblks);

			crypto_xor_cpy(dst, src, keystream,
					nblks * SM4_BLOCK_SIZE);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

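/*
 * CTR mode: encryption and decryption are identical.  Multiples of eight
 * counter blocks are handled by the NEON routine; the 1..8 block tail
 * expands the counter into a local buffer, encrypts it with blk1_8 and
 * XORs it onto the data.  A trailing partial block reuses one more
 * keystream block generated with the scalar helper.
 */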
static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_ctr_enc_blk8(ctx->rkey_enc, dst, src,
					walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			int i;

			for (i = 0; i < nblks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
					keystream, nblks);

			crypto_xor_cpy(dst, src, keystream,
					nblks * SM4_BLOCK_SIZE);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

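/*
 * Algorithm registrations.  Priority 200 ranks these above the generic
 * implementation; CFB and CTR are registered as stream-like modes with a
 * block size of 1 and a chunk size of SM4_BLOCK_SIZE.
 */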
static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "cfb(sm4)",
			.cra_driver_name	= "cfb-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= sm4_cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}
};

static int __init sm4_init(void)
{
	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("sm4-neon");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");