/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

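/* Number of whole SM4 blocks (16 bytes each) in @nbytes; any remainder is a tail. */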
#define BYTES2BLKS(nbytes)	((nbytes) >> 4)

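/*
 * Low-level primitives implemented in assembly (sm4-ce-core.S) using the
 * ARMv8 SM4 Crypto Extensions instructions. The bulk routines operate on
 * whole 16-byte blocks and, for the chained modes, update the IV/counter
 * buffer in place so it carries across calls.
 */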
asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
				  const u32 *fk, const u32 *ck);
asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
			     unsigned int nblks);
asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);

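/*
 * Expand the 128-bit user key into both the encryption and decryption
 * round key schedules in one pass, using the SM4 system parameter (FK)
 * and fixed parameter (CK) tables exported by the generic sm4 library.
 */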
static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
			  crypto_sm4_fk, crypto_sm4_ck);
	return 0;
}

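/*
 * ECB helper shared by encryption and decryption, which differ only in
 * the round key schedule passed in. Whole blocks go to the assembly
 * routine; any bytes that do not fill a block are handed back to
 * skcipher_walk_done() and revisited in the next walk step.
 */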
static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_crypt(rkey, dst, src, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

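/*
 * CBC chains through walk.iv: the assembly routines consume the IV,
 * process nblocks whole blocks, and leave the block needed for further
 * chaining in walk.iv, so the state carries correctly across walk
 * steps. Bytes short of a full block stay unprocessed in the walk.
 */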
static int sm4_cbc_crypt(struct skcipher_request *req,
			 struct sm4_ctx *ctx, bool encrypt)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			if (encrypt)
				sm4_ce_cbc_enc(ctx->rkey_enc, dst, src,
					       walk.iv, nblocks);
			else
				sm4_ce_cbc_dec(ctx->rkey_dec, dst, src,
					       walk.iv, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_cbc_crypt(req, ctx, true);
}

static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_cbc_crypt(req, ctx, false);
}

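/*
 * CFB is a stream mode: both directions run the block cipher forward,
 * so only the encryption round keys are ever used. Whole blocks take
 * the assembly fast path; a trailing partial block (only possible in
 * the final walk step) is handled by encrypting the IV once more and
 * XORing the resulting keystream with the remaining bytes.
 */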
static int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

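/*
 * CFB decryption mirrors encryption and again uses ctx->rkey_enc: the
 * keystream is always produced by encrypting the previous ciphertext
 * block (or the IV), never by running the cipher in reverse.
 */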
static int sm4_cfb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

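/*
 * CTR mode: the assembly routine encrypts the big-endian counter in
 * walk.iv and increments it for each whole block. For a trailing
 * partial block, one more counter block is encrypted here, the counter
 * is bumped with crypto_inc(), and the keystream is XORed with the
 * remaining bytes. Encryption and decryption are the same operation.
 */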
static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

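/*
 * A cra_priority of 400 is intended to rank these implementations above
 * the lower-priority generic and NEON sm4 drivers, so they win algorithm
 * selection wherever this module loads. CFB and CTR are stream modes,
 * hence cra_blocksize of 1 with a 16-byte chunksize.
 */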
static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "cfb(sm4)",
			.cra_driver_name	= "cfb-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= sm4_cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}
};

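/*
 * module_cpu_feature_match(SM4, ...) ties module autoloading to the SM4
 * CPU feature and runs sm4_init only on CPUs that advertise the SM4
 * instructions, so the asmlinkage routines above are always usable once
 * the algorithms are registered.
 */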
static int __init sm4_init(void)
{
	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_cpu_feature_match(SM4, sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 Crypto Extensions");
MODULE_ALIAS_CRYPTO("sm4-ce");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");