1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * SM4-CCM AEAD Algorithm using ARMv8 Crypto Extensions
4  * as specified in rfc8998
5  * https://datatracker.ietf.org/doc/html/rfc8998
6  *
7  * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
8  */
9 
10 #include <linux/module.h>
11 #include <linux/crypto.h>
12 #include <linux/kernel.h>
13 #include <linux/cpufeature.h>
14 #include <asm/neon.h>
15 #include <crypto/scatterwalk.h>
16 #include <crypto/internal/aead.h>
17 #include <crypto/internal/skcipher.h>
18 #include <crypto/sm4.h>
19 #include "sm4-ce.h"
20 
21 asmlinkage void sm4_ce_cbcmac_update(const u32 *rkey_enc, u8 *mac,
22 				     const u8 *src, unsigned int nblocks);
23 asmlinkage void sm4_ce_ccm_enc(const u32 *rkey_enc, u8 *dst, const u8 *src,
24 			       u8 *iv, unsigned int nbytes, u8 *mac);
25 asmlinkage void sm4_ce_ccm_dec(const u32 *rkey_enc, u8 *dst, const u8 *src,
26 			       u8 *iv, unsigned int nbytes, u8 *mac);
27 asmlinkage void sm4_ce_ccm_final(const u32 *rkey_enc, u8 *iv, u8 *mac);
28 
29 
30 static int ccm_setkey(struct crypto_aead *tfm, const u8 *key,
31 		      unsigned int key_len)
32 {
33 	struct sm4_ctx *ctx = crypto_aead_ctx(tfm);
34 
35 	if (key_len != SM4_KEY_SIZE)
36 		return -EINVAL;
37 
38 	kernel_neon_begin();
39 	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
40 			  crypto_sm4_fk, crypto_sm4_ck);
41 	kernel_neon_end();
42 
43 	return 0;
44 }
45 
46 static int ccm_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
47 {
48 	if ((authsize & 1) || authsize < 4)
49 		return -EINVAL;
50 	return 0;
51 }
52 
/*
 * Build the CCM B_0 block (RFC 3610 / NIST SP 800-38C) into @info and
 * prepare req->iv for use as the initial counter block.
 *
 * req->iv[0] carries L' = L - 1, where L is the byte width of the
 * message-length field; the nonce occupies the bytes in between.
 *
 * Returns 0 on success, -EINVAL if L is outside the allowed 2..8 range,
 * or -EOVERFLOW if @msglen cannot be represented in L bytes.
 */
static int ccm_format_input(u8 info[], struct aead_request *req,
			    unsigned int msglen)
{
	struct crypto_aead *aead = crypto_aead_reqtfm(req);
	unsigned int l = req->iv[0] + 1;
	unsigned int m;
	__be32 len;

	/* verify that CCM dimension 'L': 2 <= L <= 8 */
	if (l < 2 || l > 8)
		return -EINVAL;
	if (l < 4 && msglen >> (8 * l))
		return -EOVERFLOW;

	/* zero the counter part of the IV: it doubles as counter block A_0 */
	memset(&req->iv[SM4_BLOCK_SIZE - l], 0, l);

	memcpy(info, req->iv, SM4_BLOCK_SIZE);

	m = crypto_aead_authsize(aead);

	/* format flags field per RFC 3610/NIST 800-38C */
	*info |= ((m - 2) / 2) << 3;
	if (req->assoclen)
		*info |= (1 << 6);	/* Adata bit: AAD is present */

	/*
	 * format message length field,
	 * Linux uses a u32 type to represent msglen
	 */
	if (l >= 4)
		l = 4;

	len = cpu_to_be32(msglen);
	memcpy(&info[SM4_BLOCK_SIZE - l], (u8 *)&len + 4 - l, l);

	return 0;
}
90 
/*
 * Fold the associated data into the CBC-MAC accumulated in @mac.
 *
 * On entry @mac holds the B_0 block; it is encrypted first, then the
 * encoded AAD length and the AAD itself are absorbed.  Full blocks are
 * bulk-processed by the asm CBC-MAC routine; a partial block is xor'ed
 * into @mac with @len tracking the fill level, and any trailing partial
 * block remains buffered in @mac (CCM's zero padding is implicit, since
 * xor with zero is a no-op).
 *
 * Called with kernel-mode NEON enabled (ccm_crypt() does
 * kernel_neon_begin() before calling here).
 */
static void ccm_calculate_auth_mac(struct aead_request *req, u8 mac[])
{
	struct crypto_aead *aead = crypto_aead_reqtfm(req);
	struct sm4_ctx *ctx = crypto_aead_ctx(aead);
	struct __packed { __be16 l; __be32 h; } aadlen;
	u32 assoclen = req->assoclen;
	struct scatter_walk walk;
	unsigned int len;

	/*
	 * RFC 3610 AAD length encoding: lengths below 0xff00 are stored
	 * in two bytes; otherwise the 0xfffe marker is followed by the
	 * 32-bit big-endian length (six bytes total).
	 */
	if (assoclen < 0xff00) {
		aadlen.l = cpu_to_be16(assoclen);
		len = 2;
	} else {
		aadlen.l = cpu_to_be16(0xfffe);
		put_unaligned_be32(assoclen, &aadlen.h);
		len = 6;
	}

	/* encrypt B_0, then start the next block with the AAD length */
	sm4_ce_crypt_block(ctx->rkey_enc, mac, mac);
	crypto_xor(mac, (const u8 *)&aadlen, len);

	scatterwalk_start(&walk, req->src);

	do {
		u32 n = scatterwalk_clamp(&walk, assoclen);
		u8 *p, *ptr;

		if (!n) {
			/* current sg entry exhausted: move to the next one */
			scatterwalk_start(&walk, sg_next(walk.sg));
			n = scatterwalk_clamp(&walk, assoclen);
		}

		p = ptr = scatterwalk_map(&walk);
		assoclen -= n;
		scatterwalk_advance(&walk, n);

		while (n > 0) {
			unsigned int l, nblocks;

			if (len == SM4_BLOCK_SIZE) {
				if (n < SM4_BLOCK_SIZE) {
					/* flush the buffered full block */
					sm4_ce_crypt_block(ctx->rkey_enc,
							   mac, mac);

					len = 0;
				} else {
					/* bulk-process whole blocks in asm */
					nblocks = n / SM4_BLOCK_SIZE;
					sm4_ce_cbcmac_update(ctx->rkey_enc,
							     mac, ptr, nblocks);

					ptr += nblocks * SM4_BLOCK_SIZE;
					n %= SM4_BLOCK_SIZE;

					continue;
				}
			}

			/* absorb a partial block into the accumulator */
			l = min(n, SM4_BLOCK_SIZE - len);
			if (l) {
				crypto_xor(mac + len, ptr, l);
				len += l;
				ptr += l;
				n -= l;
			}
		}

		scatterwalk_unmap(p);
		scatterwalk_done(&walk, 0, assoclen);
	} while (assoclen);
}
161 
/*
 * Common CTR-with-CBC-MAC walk loop shared by encryption and decryption.
 *
 * @walk must already have been started with
 * skcipher_walk_aead_{en,de}crypt().  @mac holds the B_0 block on entry
 * and the final tag (MAC encrypted with the saved A_0 counter block) on
 * return.  @sm4_ce_ccm_crypt is the asm routine that CTR-crypts data
 * while folding it into the CBC-MAC.
 *
 * Kernel-mode NEON is held only around the asm calls and is dropped
 * before skcipher_walk_done(), then re-acquired if more data remains.
 */
static int ccm_crypt(struct aead_request *req, struct skcipher_walk *walk,
		     u32 *rkey_enc, u8 mac[],
		     void (*sm4_ce_ccm_crypt)(const u32 *rkey_enc, u8 *dst,
					const u8 *src, u8 *iv,
					unsigned int nbytes, u8 *mac))
{
	u8 __aligned(8) ctr0[SM4_BLOCK_SIZE];
	int err;

	/* preserve the initial ctr0 for the TAG */
	memcpy(ctr0, walk->iv, SM4_BLOCK_SIZE);
	crypto_inc(walk->iv, SM4_BLOCK_SIZE);	/* data starts at A_1 */

	kernel_neon_begin();

	if (req->assoclen)
		ccm_calculate_auth_mac(req, mac);

	do {
		unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;
		const u8 *src = walk->src.virt.addr;
		u8 *dst = walk->dst.virt.addr;

		/* the final chunk also processes its partial last block */
		if (walk->nbytes == walk->total)
			tail = 0;

		if (walk->nbytes - tail)
			sm4_ce_ccm_crypt(rkey_enc, dst, src, walk->iv,
					 walk->nbytes - tail, mac);

		/* all data consumed: encrypt the MAC with A_0 to form the tag */
		if (walk->nbytes == walk->total)
			sm4_ce_ccm_final(rkey_enc, ctr0, mac);

		kernel_neon_end();

		if (walk->nbytes) {
			/* hand the unprocessed tail back to the walker */
			err = skcipher_walk_done(walk, tail);
			if (err)
				return err;
			if (walk->nbytes)
				kernel_neon_begin();
		}
	} while (walk->nbytes > 0);

	return 0;
}
208 
209 static int ccm_encrypt(struct aead_request *req)
210 {
211 	struct crypto_aead *aead = crypto_aead_reqtfm(req);
212 	struct sm4_ctx *ctx = crypto_aead_ctx(aead);
213 	u8 __aligned(8) mac[SM4_BLOCK_SIZE];
214 	struct skcipher_walk walk;
215 	int err;
216 
217 	err = ccm_format_input(mac, req, req->cryptlen);
218 	if (err)
219 		return err;
220 
221 	err = skcipher_walk_aead_encrypt(&walk, req, false);
222 	if (err)
223 		return err;
224 
225 	err = ccm_crypt(req, &walk, ctx->rkey_enc, mac, sm4_ce_ccm_enc);
226 	if (err)
227 		return err;
228 
229 	/* copy authtag to end of dst */
230 	scatterwalk_map_and_copy(mac, req->dst, req->assoclen + req->cryptlen,
231 				 crypto_aead_authsize(aead), 1);
232 
233 	return 0;
234 }
235 
236 static int ccm_decrypt(struct aead_request *req)
237 {
238 	struct crypto_aead *aead = crypto_aead_reqtfm(req);
239 	unsigned int authsize = crypto_aead_authsize(aead);
240 	struct sm4_ctx *ctx = crypto_aead_ctx(aead);
241 	u8 __aligned(8) mac[SM4_BLOCK_SIZE];
242 	u8 authtag[SM4_BLOCK_SIZE];
243 	struct skcipher_walk walk;
244 	int err;
245 
246 	err = ccm_format_input(mac, req, req->cryptlen - authsize);
247 	if (err)
248 		return err;
249 
250 	err = skcipher_walk_aead_decrypt(&walk, req, false);
251 	if (err)
252 		return err;
253 
254 	err = ccm_crypt(req, &walk, ctx->rkey_enc, mac, sm4_ce_ccm_dec);
255 	if (err)
256 		return err;
257 
258 	/* compare calculated auth tag with the stored one */
259 	scatterwalk_map_and_copy(authtag, req->src,
260 				 req->assoclen + req->cryptlen - authsize,
261 				 authsize, 0);
262 
263 	if (crypto_memneq(authtag, mac, authsize))
264 		return -EBADMSG;
265 
266 	return 0;
267 }
268 
/* Algorithm registration for "ccm(sm4)" backed by the CE asm routines. */
static struct aead_alg sm4_ccm_alg = {
	.base = {
		.cra_name		= "ccm(sm4)",
		.cra_driver_name	= "ccm-sm4-ce",
		.cra_priority		= 400,
		.cra_blocksize		= 1,	/* CTR-based, byte granular */
		.cra_ctxsize		= sizeof(struct sm4_ctx),
		.cra_module		= THIS_MODULE,
	},
	.ivsize		= SM4_BLOCK_SIZE,	/* flags byte + nonce + counter */
	.chunksize	= SM4_BLOCK_SIZE,
	.maxauthsize	= SM4_BLOCK_SIZE,
	.setkey		= ccm_setkey,
	.setauthsize	= ccm_setauthsize,
	.encrypt	= ccm_encrypt,
	.decrypt	= ccm_decrypt,
};
286 
/* Module init: register the "ccm(sm4)" AEAD with the crypto API. */
static int __init sm4_ce_ccm_init(void)
{
	return crypto_register_aead(&sm4_ccm_alg);
}
291 
/* Module exit: unregister the AEAD registered in sm4_ce_ccm_init(). */
static void __exit sm4_ce_ccm_exit(void)
{
	crypto_unregister_aead(&sm4_ccm_alg);
}
296 
297 module_cpu_feature_match(SM4, sm4_ce_ccm_init);
298 module_exit(sm4_ce_ccm_exit);
299 
300 MODULE_DESCRIPTION("Synchronous SM4 in CCM mode using ARMv8 Crypto Extensions");
301 MODULE_ALIAS_CRYPTO("ccm(sm4)");
302 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
303 MODULE_LICENSE("GPL v2");
304