/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized,
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)

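/*
 * Implementations in the companion AVX assembly (sm4-aesni-avx-asm_64.S):
 * crypt4 and crypt8 encrypt or decrypt up to 4 or 8 blocks in parallel
 * using the round keys in 'rk'; the *_blk8 helpers each consume exactly
 * 8 blocks and update the chaining value or counter in 'iv' themselves.
 */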
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

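/*
 * Shared ECB helper: within each walk step, process 8-block chunks with the
 * 8-way implementation, then drain the remaining full blocks (up to four at
 * a time) with the 4-way one, all inside a single kernel_fpu_begin()/end()
 * section per step.
 */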
static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();
		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
			dst += SM4_CRYPT8_BLOCK_SIZE;
			src += SM4_CRYPT8_BLOCK_SIZE;
			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
		}
		while (nbytes >= SM4_BLOCK_SIZE) {
			unsigned int nblocks = min(nbytes >> 4, 4u);
			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}
		kernel_fpu_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

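/*
 * CBC encryption is inherently serial (each block is chained into the
 * next), so there is nothing for the 8-way code to parallelize; use the
 * generic sm4_crypt_block() and XOR the running IV in by hand.
 */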
int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

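/*
 * Generic CBC decryption helper: 'func' is an asm routine that decrypts
 * and unchains 'bsize' bytes at a time (updating walk.iv itself); any full
 * blocks left over are handled with a bulk decrypt followed by a backward
 * XOR pass.
 */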
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
						src, nblocks);

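			/*
			 * All nblocks are now decrypted in 'keystream'.
			 * Save the last ciphertext block (the next IV)
			 * before it can be clobbered, then XOR from the
			 * last block backward so that in-place requests
			 * (src == dst) do not overwrite ciphertext that
			 * is still needed as chaining input.
			 */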
			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cbc_dec_blk8);
}

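/*
 * Like CBC, CFB encryption cannot be parallelized: each keystream block is
 * the encryption of the previous ciphertext block.
 */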
int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);

int sm4_avx_cfb_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);

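			/*
			 * Unlike encryption, CFB decryption parallelizes:
			 * the keystream blocks E(IV), E(C0), ..., E(Cn-2)
			 * are all independent. Gather the IV plus the first
			 * nblocks - 1 ciphertext blocks, stash the last
			 * ciphertext block as the next IV, then encrypt
			 * the lot with one 8-way call.
			 */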
			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			if (nblocks > 1)
				memcpy(&keystream[SM4_BLOCK_SIZE], src,
					(nblocks - 1) * SM4_BLOCK_SIZE);
			memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
				SM4_BLOCK_SIZE);

			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
						keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);

static int cfb_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cfb_dec_blk8);
}

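/*
 * CTR helper with the same shape as the CBC/CFB ones: an asm fast path for
 * 'bsize' chunks, an up-to-8-block keystream fallback, then a partial
 * final block.
 */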
int sm4_avx_ctr_crypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

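			/*
			 * Expand the big-endian counter in walk.iv into
			 * nblocks counter blocks, then encrypt them all
			 * with one 8-way call to form the keystream.
			 */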
			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			dst += nbytes;
			src += nbytes;
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_ctr_enc_blk8);
}

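/*
 * These algorithms are registered as CRYPTO_ALG_INTERNAL with a "__" name
 * prefix: they must only run while the FPU is usable, so sm4_init() wraps
 * them with the crypto SIMD helpers, which expose the plain "ecb(sm4)",
 * "cbc(sm4)", "cfb(sm4)" and "ctr(sm4)" instances and defer to a cryptd
 * context when SIMD is not available.
 */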
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(sm4)",
			.cra_driver_name	= "__ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(sm4)",
			.cra_driver_name	= "__cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__cfb(sm4)",
			.cra_driver_name	= "__cfb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "__ctr(sm4)",
			.cra_driver_name	= "__ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};

static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];

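/*
 * Minimal usage sketch (illustrative only; error handling, request flags
 * and scatterlist setup omitted). Once this module is loaded, other kernel
 * code reaches the SIMD-wrapped algorithms through the regular skcipher
 * API, never the "__"-prefixed internal ones:
 *
 *	struct crypto_skcipher *tfm = crypto_alloc_skcipher("ctr(sm4)", 0, 0);
 *	struct skcipher_request *req = skcipher_request_alloc(tfm, GFP_KERNEL);
 *
 *	crypto_skcipher_setkey(tfm, key, SM4_KEY_SIZE);
 *	skcipher_request_set_crypt(req, src_sg, dst_sg, len, iv);
 *	crypto_skcipher_encrypt(req);
 */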
static int __init sm4_init(void)
{
	const char *feature_name;

	if (!boot_cpu_has(X86_FEATURE_AVX) ||
	    !boot_cpu_has(X86_FEATURE_AES) ||
	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
		pr_info("AVX or AES-NI instructions are not detected.\n");
		return -ENODEV;
	}

	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
				&feature_name)) {
		pr_info("CPU feature '%s' is not supported.\n", feature_name);
		return -ENODEV;
	}

	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

static void __exit sm4_exit(void)
{
	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");