xref: /openbmc/u-boot/lib/rsa/rsa-verify.c (revision 2cb0e55a)
1 /*
2  * Copyright (c) 2013, Google Inc.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License as
6  * published by the Free Software Foundation; either version 2 of
7  * the License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 59 Temple Place, Suite 330, Boston,
17  * MA 02111-1307 USA
18  */
19 
20 #include <common.h>
21 #include <fdtdec.h>
22 #include <rsa.h>
23 #include <sha1.h>
24 #include <asm/byteorder.h>
25 #include <asm/errno.h>
26 #include <asm/unaligned.h>
27 
/**
 * struct rsa_public_key - holder for a public key
 *
 * An RSA public key in the form consumed by the Montgomery arithmetic in
 * this file: the modulus N plus two precomputed values (read from the
 * "rsa,n0-inverse" and "rsa,r-squared" device-tree properties), with
 * R = 2^(# key bits).
 */
struct rsa_public_key {
	uint len;		/* Length of modulus[] and rr[] in number of uint32_t */
	uint32_t n0inv;		/* -1 / modulus[0] mod 2^32 */
	uint32_t *modulus;	/* modulus as little endian array */
	uint32_t *rr;		/* R^2 as little endian array; it fits in len words, so it is stored reduced mod modulus */
};
39 };
40 
/*
 * Widen v to 64 bits and multiply it by multby truncated to 32 bits.
 * NOTE(review): not referenced anywhere in this file.
 */
#define UINT64_MULT32(v, multby)  (((uint64_t)(v)) * ((uint32_t)(multby)))

/* Number of bytes in a 2048-bit modulus (and signature) */
#define RSA2048_BYTES	(2048 / 8)

/*
 * This is the minimum/maximum key size we support, in bits. The maximum
 * also bounds the on-stack arrays sized from the key length. Both match the
 * only padding table available, padding_sha1_rsa2048.
 */
#define RSA_MIN_KEY_BITS	2048
#define RSA_MAX_KEY_BITS	2048

/* This is the maximum signature length that we support, in bits */
#define RSA_MAX_SIG_BITS	2048
51 
/*
 * Expected PKCS#1 v1.5 (EMSA-PKCS1-v1_5) encoding of a SHA-1 digest for a
 * 2048-bit key, without the trailing SHA1_SUM_LEN digest bytes:
 *   0x00 0x01, 218 bytes of 0xff, a 0x00 separator, then the DER
 *   DigestInfo prefix for SHA-1: 30 21 30 09 06 05 2b 0e 03 02 1a 05 00 04 14.
 * rsa_verify_key() compares the decrypted signature against this table and
 * then compares the bytes that follow it with the expected digest.
 */
static const uint8_t padding_sha1_rsa2048[RSA2048_BYTES - SHA1_SUM_LEN] = {
	0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0x00, 0x30, 0x21, 0x30,
	0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a,
	0x05, 0x00, 0x04, 0x14
};
84 
85 /**
86  * subtract_modulus() - subtract modulus from the given value
87  *
88  * @key:	Key containing modulus to subtract
89  * @num:	Number to subtract modulus from, as little endian word array
90  */
91 static void subtract_modulus(const struct rsa_public_key *key, uint32_t num[])
92 {
93 	int64_t acc = 0;
94 	uint i;
95 
96 	for (i = 0; i < key->len; i++) {
97 		acc += (uint64_t)num[i] - key->modulus[i];
98 		num[i] = (uint32_t)acc;
99 		acc >>= 32;
100 	}
101 }
102 
103 /**
104  * greater_equal_modulus() - check if a value is >= modulus
105  *
106  * @key:	Key containing modulus to check
107  * @num:	Number to check against modulus, as little endian word array
108  * @return 0 if num < modulus, 1 if num >= modulus
109  */
110 static int greater_equal_modulus(const struct rsa_public_key *key,
111 				 uint32_t num[])
112 {
113 	uint32_t i;
114 
115 	for (i = key->len - 1; i >= 0; i--) {
116 		if (num[i] < key->modulus[i])
117 			return 0;
118 		if (num[i] > key->modulus[i])
119 			return 1;
120 	}
121 
122 	return 1;  /* equal */
123 }
124 
/**
 * montgomery_mul_add_step() - Perform one word step of a Montgomery multiply
 *
 * Operation: result[] = (result[] + a * b[] + d0 * modulus[]) / 2^32, where
 * the single word d0 is chosen so the sum is divisible by 2^32. The result
 * is therefore congruent to (result[] + a * b[]) / 2^32 modulo the modulus,
 * but is not necessarily fully reduced below it.
 *
 * @key:	RSA key
 * @result:	Accumulator, as little endian word array of key->len words;
 *		updated in place
 * @a:		Multiplier (a single word)
 * @b:		Multiplicand, as little endian word array
 */
static void montgomery_mul_add_step(const struct rsa_public_key *key,
		uint32_t result[], const uint32_t a, const uint32_t b[])
{
	uint64_t acc_a, acc_b;
	uint32_t d0;
	uint i;

	/* acc_a runs the carry chain of result[] + a * b[] */
	acc_a = (uint64_t)a * b[0] + result[0];
	/*
	 * Since n0inv = -1 / modulus[0] mod 2^32, this d0 makes the low word
	 * of acc_b below zero; that word is dropped (the divide by 2^32).
	 */
	d0 = (uint32_t)acc_a * key->n0inv;
	/* acc_b runs the carry chain of (result[] + a * b[]) + d0 * modulus[] */
	acc_b = (uint64_t)d0 * key->modulus[0] + (uint32_t)acc_a;
	for (i = 1; i < key->len; i++) {
		acc_a = (acc_a >> 32) + (uint64_t)a * b[i] + result[i];
		acc_b = (acc_b >> 32) + (uint64_t)d0 * key->modulus[i] +
				(uint32_t)acc_a;
		/* Store one word lower: this is the shift right by 32 bits */
		result[i - 1] = (uint32_t)acc_b;
	}

	/* Combine the carries out of both chains into the top word */
	acc_a = (acc_a >> 32) + (acc_b >> 32);

	result[i - 1] = (uint32_t)acc_a;

	/*
	 * A carry out of the top word does not fit in key->len words, so
	 * subtract the modulus once; the final borrow that subtract_modulus()
	 * discards cancels the lost carry bit.
	 */
	if (acc_a >> 32)
		subtract_modulus(key, result);
}
159 
160 /**
161  * montgomery_mul() - Perform montgomery mutitply
162  *
163  * Operation: montgomery result[] = a[] * b[] / n0inv % modulus
164  *
165  * @key:	RSA key
166  * @result:	Place to put result, as little endian word array
167  * @a:		Multiplier, as little endian word array
168  * @b:		Multiplicand, as little endian word array
169  */
170 static void montgomery_mul(const struct rsa_public_key *key,
171 		uint32_t result[], uint32_t a[], const uint32_t b[])
172 {
173 	uint i;
174 
175 	for (i = 0; i < key->len; ++i)
176 		result[i] = 0;
177 	for (i = 0; i < key->len; ++i)
178 		montgomery_mul_add_step(key, result, a[i], b);
179 }
180 
181 /**
182  * pow_mod() - in-place public exponentiation
183  *
184  * @key:	RSA key
185  * @inout:	Big-endian word array containing value and result
186  */
187 static int pow_mod(const struct rsa_public_key *key, uint32_t *inout)
188 {
189 	uint32_t *result, *ptr;
190 	uint i;
191 
192 	/* Sanity check for stack size - key->len is in 32-bit words */
193 	if (key->len > RSA_MAX_KEY_BITS / 32) {
194 		debug("RSA key words %u exceeds maximum %d\n", key->len,
195 		      RSA_MAX_KEY_BITS / 32);
196 		return -EINVAL;
197 	}
198 
199 	uint32_t val[key->len], acc[key->len], tmp[key->len];
200 	result = tmp;  /* Re-use location. */
201 
202 	/* Convert from big endian byte array to little endian word array. */
203 	for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--)
204 		val[i] = get_unaligned_be32(ptr);
205 
206 	montgomery_mul(key, acc, val, key->rr);  /* axx = a * RR / R mod M */
207 	for (i = 0; i < 16; i += 2) {
208 		montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod M */
209 		montgomery_mul(key, acc, tmp, tmp); /* acc = tmp^2 / R mod M */
210 	}
211 	montgomery_mul(key, result, acc, val);  /* result = XX * a / R mod M */
212 
213 	/* Make sure result < mod; result is at most 1x mod too large. */
214 	if (greater_equal_modulus(key, result))
215 		subtract_modulus(key, result);
216 
217 	/* Convert to bigendian byte array */
218 	for (i = key->len - 1, ptr = inout; (int)i >= 0; i--, ptr++)
219 		put_unaligned_be32(result[i], ptr);
220 
221 	return 0;
222 }
223 
224 static int rsa_verify_key(const struct rsa_public_key *key, const uint8_t *sig,
225 		const uint32_t sig_len, const uint8_t *hash)
226 {
227 	const uint8_t *padding;
228 	int pad_len;
229 	int ret;
230 
231 	if (!key || !sig || !hash)
232 		return -EIO;
233 
234 	if (sig_len != (key->len * sizeof(uint32_t))) {
235 		debug("Signature is of incorrect length %d\n", sig_len);
236 		return -EINVAL;
237 	}
238 
239 	/* Sanity check for stack size */
240 	if (sig_len > RSA_MAX_SIG_BITS / 8) {
241 		debug("Signature length %u exceeds maximum %d\n", sig_len,
242 		      RSA_MAX_SIG_BITS / 8);
243 		return -EINVAL;
244 	}
245 
246 	uint32_t buf[sig_len / sizeof(uint32_t)];
247 
248 	memcpy(buf, sig, sig_len);
249 
250 	ret = pow_mod(key, buf);
251 	if (ret)
252 		return ret;
253 
254 	/* Determine padding to use depending on the signature type. */
255 	padding = padding_sha1_rsa2048;
256 	pad_len = RSA2048_BYTES - SHA1_SUM_LEN;
257 
258 	/* Check pkcs1.5 padding bytes. */
259 	if (memcmp(buf, padding, pad_len)) {
260 		debug("In RSAVerify(): Padding check failed!\n");
261 		return -EINVAL;
262 	}
263 
264 	/* Check hash. */
265 	if (memcmp((uint8_t *)buf + pad_len, hash, sig_len - pad_len)) {
266 		debug("In RSAVerify(): Hash check failed!\n");
267 		return -EACCES;
268 	}
269 
270 	return 0;
271 }
272 
/**
 * rsa_convert_big_endian() - turn a device-tree big number into a word array
 *
 * Reverses the word order of @src while converting each word with
 * fdt32_to_cpu(). The result is the least-significant-word-first layout
 * used by struct rsa_public_key.
 *
 * @dst:	Output array of @len words
 * @src:	Property data of @len fdt32 words, most significant word first
 * @len:	Number of 32-bit words to convert; nothing is done if <= 0
 */
static void rsa_convert_big_endian(uint32_t *dst, const uint32_t *src, int len)
{
	int remaining = len;

	while (remaining > 0) {
		remaining--;
		*dst = fdt32_to_cpu(src[remaining]);
		dst++;
	}
}
280 
281 static int rsa_verify_with_keynode(struct image_sign_info *info,
282 		const void *hash, uint8_t *sig, uint sig_len, int node)
283 {
284 	const void *blob = info->fdt_blob;
285 	struct rsa_public_key key;
286 	const void *modulus, *rr;
287 	int ret;
288 
289 	if (node < 0) {
290 		debug("%s: Skipping invalid node", __func__);
291 		return -EBADF;
292 	}
293 	if (!fdt_getprop(blob, node, "rsa,n0-inverse", NULL)) {
294 		debug("%s: Missing rsa,n0-inverse", __func__);
295 		return -EFAULT;
296 	}
297 	key.len = fdtdec_get_int(blob, node, "rsa,num-bits", 0);
298 	key.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0);
299 	modulus = fdt_getprop(blob, node, "rsa,modulus", NULL);
300 	rr = fdt_getprop(blob, node, "rsa,r-squared", NULL);
301 	if (!key.len || !modulus || !rr) {
302 		debug("%s: Missing RSA key info", __func__);
303 		return -EFAULT;
304 	}
305 
306 	/* Sanity check for stack size */
307 	if (key.len > RSA_MAX_KEY_BITS || key.len < RSA_MIN_KEY_BITS) {
308 		debug("RSA key bits %u outside allowed range %d..%d\n",
309 		      key.len, RSA_MIN_KEY_BITS, RSA_MAX_KEY_BITS);
310 		return -EFAULT;
311 	}
312 	key.len /= sizeof(uint32_t) * 8;
313 	uint32_t key1[key.len], key2[key.len];
314 
315 	key.modulus = key1;
316 	key.rr = key2;
317 	rsa_convert_big_endian(key.modulus, modulus, key.len);
318 	rsa_convert_big_endian(key.rr, rr, key.len);
319 	if (!key.modulus || !key.rr) {
320 		debug("%s: Out of memory", __func__);
321 		return -ENOMEM;
322 	}
323 
324 	debug("key length %d\n", key.len);
325 	ret = rsa_verify_key(&key, sig, sig_len, hash);
326 	if (ret) {
327 		printf("%s: RSA failed to verify: %d\n", __func__, ret);
328 		return ret;
329 	}
330 
331 	return 0;
332 }
333 
334 int rsa_verify(struct image_sign_info *info,
335 	       const struct image_region region[], int region_count,
336 	       uint8_t *sig, uint sig_len)
337 {
338 	const void *blob = info->fdt_blob;
339 	uint8_t hash[SHA1_SUM_LEN];
340 	int ndepth, noffset;
341 	int sig_node, node;
342 	char name[100];
343 	sha1_context ctx;
344 	int ret, i;
345 
346 	sig_node = fdt_subnode_offset(blob, 0, FIT_SIG_NODENAME);
347 	if (sig_node < 0) {
348 		debug("%s: No signature node found\n", __func__);
349 		return -ENOENT;
350 	}
351 
352 	sha1_starts(&ctx);
353 	for (i = 0; i < region_count; i++)
354 		sha1_update(&ctx, region[i].data, region[i].size);
355 	sha1_finish(&ctx, hash);
356 
357 	/* See if we must use a particular key */
358 	if (info->required_keynode != -1) {
359 		ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
360 			info->required_keynode);
361 		if (!ret)
362 			return ret;
363 	}
364 
365 	/* Look for a key that matches our hint */
366 	snprintf(name, sizeof(name), "key-%s", info->keyname);
367 	node = fdt_subnode_offset(blob, sig_node, name);
368 	ret = rsa_verify_with_keynode(info, hash, sig, sig_len, node);
369 	if (!ret)
370 		return ret;
371 
372 	/* No luck, so try each of the keys in turn */
373 	for (ndepth = 0, noffset = fdt_next_node(info->fit, sig_node, &ndepth);
374 			(noffset >= 0) && (ndepth > 0);
375 			noffset = fdt_next_node(info->fit, noffset, &ndepth)) {
376 		if (ndepth == 1 && noffset != node) {
377 			ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
378 						      noffset);
379 			if (!ret)
380 				break;
381 		}
382 	}
383 
384 	return ret;
385 }
386