xref: /openbmc/linux/arch/arm64/crypto/sm4-ce-glue.c (revision 61c1f340bc809a1ca1e3c8794207a91cde1a7c78)
// SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
4  * as specified in
5  * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6  *
7  * Copyright (C) 2022, Alibaba Group.
8  * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9  */
10 
11 #include <linux/module.h>
12 #include <linux/crypto.h>
13 #include <linux/kernel.h>
14 #include <linux/cpufeature.h>
15 #include <asm/neon.h>
16 #include <asm/simd.h>
17 #include <crypto/internal/simd.h>
18 #include <crypto/internal/skcipher.h>
19 #include <crypto/sm4.h>
20 
21 #define BYTES2BLKS(nbytes)	((nbytes) >> 4)
22 
23 asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
24 				  const u32 *fk, const u32 *ck);
25 asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
26 asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
27 			     unsigned int nblks);
28 asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
29 			       u8 *iv, unsigned int nblks);
30 asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
31 			       u8 *iv, unsigned int nblks);
32 asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
33 			       u8 *iv, unsigned int nblks);
34 asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
35 			       u8 *iv, unsigned int nblks);
36 asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
37 			       u8 *iv, unsigned int nblks);
38 
39 static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
40 		      unsigned int key_len)
41 {
42 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
43 
44 	if (key_len != SM4_KEY_SIZE)
45 		return -EINVAL;
46 
47 	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
48 			  crypto_sm4_fk, crypto_sm4_ck);
49 	return 0;
50 }
51 
52 static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
53 {
54 	struct skcipher_walk walk;
55 	unsigned int nbytes;
56 	int err;
57 
58 	err = skcipher_walk_virt(&walk, req, false);
59 
60 	while ((nbytes = walk.nbytes) > 0) {
61 		const u8 *src = walk.src.virt.addr;
62 		u8 *dst = walk.dst.virt.addr;
63 		unsigned int nblks;
64 
65 		kernel_neon_begin();
66 
67 		nblks = BYTES2BLKS(nbytes);
68 		if (nblks) {
69 			sm4_ce_crypt(rkey, dst, src, nblks);
70 			nbytes -= nblks * SM4_BLOCK_SIZE;
71 		}
72 
73 		kernel_neon_end();
74 
75 		err = skcipher_walk_done(&walk, nbytes);
76 	}
77 
78 	return err;
79 }
80 
81 static int sm4_ecb_encrypt(struct skcipher_request *req)
82 {
83 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
84 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
85 
86 	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
87 }
88 
89 static int sm4_ecb_decrypt(struct skcipher_request *req)
90 {
91 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
92 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
93 
94 	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
95 }
96 
97 static int sm4_cbc_encrypt(struct skcipher_request *req)
98 {
99 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
100 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
101 	struct skcipher_walk walk;
102 	unsigned int nbytes;
103 	int err;
104 
105 	err = skcipher_walk_virt(&walk, req, false);
106 
107 	while ((nbytes = walk.nbytes) > 0) {
108 		const u8 *src = walk.src.virt.addr;
109 		u8 *dst = walk.dst.virt.addr;
110 		unsigned int nblks;
111 
112 		kernel_neon_begin();
113 
114 		nblks = BYTES2BLKS(nbytes);
115 		if (nblks) {
116 			sm4_ce_cbc_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
117 			nbytes -= nblks * SM4_BLOCK_SIZE;
118 		}
119 
120 		kernel_neon_end();
121 
122 		err = skcipher_walk_done(&walk, nbytes);
123 	}
124 
125 	return err;
126 }
127 
128 static int sm4_cbc_decrypt(struct skcipher_request *req)
129 {
130 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
131 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
132 	struct skcipher_walk walk;
133 	unsigned int nbytes;
134 	int err;
135 
136 	err = skcipher_walk_virt(&walk, req, false);
137 
138 	while ((nbytes = walk.nbytes) > 0) {
139 		const u8 *src = walk.src.virt.addr;
140 		u8 *dst = walk.dst.virt.addr;
141 		unsigned int nblks;
142 
143 		kernel_neon_begin();
144 
145 		nblks = BYTES2BLKS(nbytes);
146 		if (nblks) {
147 			sm4_ce_cbc_dec(ctx->rkey_dec, dst, src, walk.iv, nblks);
148 			nbytes -= nblks * SM4_BLOCK_SIZE;
149 		}
150 
151 		kernel_neon_end();
152 
153 		err = skcipher_walk_done(&walk, nbytes);
154 	}
155 
156 	return err;
157 }
158 
159 static int sm4_cfb_encrypt(struct skcipher_request *req)
160 {
161 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
162 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
163 	struct skcipher_walk walk;
164 	unsigned int nbytes;
165 	int err;
166 
167 	err = skcipher_walk_virt(&walk, req, false);
168 
169 	while ((nbytes = walk.nbytes) > 0) {
170 		const u8 *src = walk.src.virt.addr;
171 		u8 *dst = walk.dst.virt.addr;
172 		unsigned int nblks;
173 
174 		kernel_neon_begin();
175 
176 		nblks = BYTES2BLKS(nbytes);
177 		if (nblks) {
178 			sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
179 			dst += nblks * SM4_BLOCK_SIZE;
180 			src += nblks * SM4_BLOCK_SIZE;
181 			nbytes -= nblks * SM4_BLOCK_SIZE;
182 		}
183 
184 		/* tail */
185 		if (walk.nbytes == walk.total && nbytes > 0) {
186 			u8 keystream[SM4_BLOCK_SIZE];
187 
188 			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
189 			crypto_xor_cpy(dst, src, keystream, nbytes);
190 			nbytes = 0;
191 		}
192 
193 		kernel_neon_end();
194 
195 		err = skcipher_walk_done(&walk, nbytes);
196 	}
197 
198 	return err;
199 }
200 
201 static int sm4_cfb_decrypt(struct skcipher_request *req)
202 {
203 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
204 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
205 	struct skcipher_walk walk;
206 	unsigned int nbytes;
207 	int err;
208 
209 	err = skcipher_walk_virt(&walk, req, false);
210 
211 	while ((nbytes = walk.nbytes) > 0) {
212 		const u8 *src = walk.src.virt.addr;
213 		u8 *dst = walk.dst.virt.addr;
214 		unsigned int nblks;
215 
216 		kernel_neon_begin();
217 
218 		nblks = BYTES2BLKS(nbytes);
219 		if (nblks) {
220 			sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks);
221 			dst += nblks * SM4_BLOCK_SIZE;
222 			src += nblks * SM4_BLOCK_SIZE;
223 			nbytes -= nblks * SM4_BLOCK_SIZE;
224 		}
225 
226 		/* tail */
227 		if (walk.nbytes == walk.total && nbytes > 0) {
228 			u8 keystream[SM4_BLOCK_SIZE];
229 
230 			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
231 			crypto_xor_cpy(dst, src, keystream, nbytes);
232 			nbytes = 0;
233 		}
234 
235 		kernel_neon_end();
236 
237 		err = skcipher_walk_done(&walk, nbytes);
238 	}
239 
240 	return err;
241 }
242 
243 static int sm4_ctr_crypt(struct skcipher_request *req)
244 {
245 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
246 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
247 	struct skcipher_walk walk;
248 	unsigned int nbytes;
249 	int err;
250 
251 	err = skcipher_walk_virt(&walk, req, false);
252 
253 	while ((nbytes = walk.nbytes) > 0) {
254 		const u8 *src = walk.src.virt.addr;
255 		u8 *dst = walk.dst.virt.addr;
256 		unsigned int nblks;
257 
258 		kernel_neon_begin();
259 
260 		nblks = BYTES2BLKS(nbytes);
261 		if (nblks) {
262 			sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
263 			dst += nblks * SM4_BLOCK_SIZE;
264 			src += nblks * SM4_BLOCK_SIZE;
265 			nbytes -= nblks * SM4_BLOCK_SIZE;
266 		}
267 
268 		/* tail */
269 		if (walk.nbytes == walk.total && nbytes > 0) {
270 			u8 keystream[SM4_BLOCK_SIZE];
271 
272 			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
273 			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
274 			crypto_xor_cpy(dst, src, keystream, nbytes);
275 			nbytes = 0;
276 		}
277 
278 		kernel_neon_end();
279 
280 		err = skcipher_walk_done(&walk, nbytes);
281 	}
282 
283 	return err;
284 }
285 
286 static struct skcipher_alg sm4_algs[] = {
287 	{
288 		.base = {
289 			.cra_name		= "ecb(sm4)",
290 			.cra_driver_name	= "ecb-sm4-ce",
291 			.cra_priority		= 400,
292 			.cra_blocksize		= SM4_BLOCK_SIZE,
293 			.cra_ctxsize		= sizeof(struct sm4_ctx),
294 			.cra_module		= THIS_MODULE,
295 		},
296 		.min_keysize	= SM4_KEY_SIZE,
297 		.max_keysize	= SM4_KEY_SIZE,
298 		.setkey		= sm4_setkey,
299 		.encrypt	= sm4_ecb_encrypt,
300 		.decrypt	= sm4_ecb_decrypt,
301 	}, {
302 		.base = {
303 			.cra_name		= "cbc(sm4)",
304 			.cra_driver_name	= "cbc-sm4-ce",
305 			.cra_priority		= 400,
306 			.cra_blocksize		= SM4_BLOCK_SIZE,
307 			.cra_ctxsize		= sizeof(struct sm4_ctx),
308 			.cra_module		= THIS_MODULE,
309 		},
310 		.min_keysize	= SM4_KEY_SIZE,
311 		.max_keysize	= SM4_KEY_SIZE,
312 		.ivsize		= SM4_BLOCK_SIZE,
313 		.setkey		= sm4_setkey,
314 		.encrypt	= sm4_cbc_encrypt,
315 		.decrypt	= sm4_cbc_decrypt,
316 	}, {
317 		.base = {
318 			.cra_name		= "cfb(sm4)",
319 			.cra_driver_name	= "cfb-sm4-ce",
320 			.cra_priority		= 400,
321 			.cra_blocksize		= 1,
322 			.cra_ctxsize		= sizeof(struct sm4_ctx),
323 			.cra_module		= THIS_MODULE,
324 		},
325 		.min_keysize	= SM4_KEY_SIZE,
326 		.max_keysize	= SM4_KEY_SIZE,
327 		.ivsize		= SM4_BLOCK_SIZE,
328 		.chunksize	= SM4_BLOCK_SIZE,
329 		.setkey		= sm4_setkey,
330 		.encrypt	= sm4_cfb_encrypt,
331 		.decrypt	= sm4_cfb_decrypt,
332 	}, {
333 		.base = {
334 			.cra_name		= "ctr(sm4)",
335 			.cra_driver_name	= "ctr-sm4-ce",
336 			.cra_priority		= 400,
337 			.cra_blocksize		= 1,
338 			.cra_ctxsize		= sizeof(struct sm4_ctx),
339 			.cra_module		= THIS_MODULE,
340 		},
341 		.min_keysize	= SM4_KEY_SIZE,
342 		.max_keysize	= SM4_KEY_SIZE,
343 		.ivsize		= SM4_BLOCK_SIZE,
344 		.chunksize	= SM4_BLOCK_SIZE,
345 		.setkey		= sm4_setkey,
346 		.encrypt	= sm4_ctr_crypt,
347 		.decrypt	= sm4_ctr_crypt,
348 	}
349 };
350 
351 static int __init sm4_init(void)
352 {
353 	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
354 }
355 
356 static void __exit sm4_exit(void)
357 {
358 	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
359 }
360 
361 module_cpu_feature_match(SM4, sm4_init);
362 module_exit(sm4_exit);
363 
364 MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 Crypto Extensions");
365 MODULE_ALIAS_CRYPTO("sm4-ce");
366 MODULE_ALIAS_CRYPTO("sm4");
367 MODULE_ALIAS_CRYPTO("ecb(sm4)");
368 MODULE_ALIAS_CRYPTO("cbc(sm4)");
369 MODULE_ALIAS_CRYPTO("cfb(sm4)");
370 MODULE_ALIAS_CRYPTO("ctr(sm4)");
371 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
372 MODULE_LICENSE("GPL v2");
373