xref: /openbmc/u-boot/lib/rsa/rsa-verify.c (revision 0b304a24)
1 /*
2  * Copyright (c) 2013, Google Inc.
3  *
4  * SPDX-License-Identifier:	GPL-2.0+
5  */
6 
7 #ifndef USE_HOSTCC
8 #include <common.h>
9 #include <fdtdec.h>
10 #include <asm/types.h>
11 #include <asm/byteorder.h>
12 #include <asm/errno.h>
13 #include <asm/types.h>
14 #include <asm/unaligned.h>
15 #else
16 #include "fdt_host.h"
17 #include "mkimage.h"
18 #include <fdt_support.h>
19 #endif
20 #include <u-boot/rsa.h>
21 #include <u-boot/sha1.h>
22 #include <u-boot/sha256.h>
23 
/* Full 64-bit product of a value and a 32-bit multiplier */
#define UINT64_MULT32(v, multby)  (((uint64_t)(v)) * ((uint32_t)(multby)))

/*
 * Big-endian (FDT byte order) 32-bit load/store through a uint32_t pointer.
 * NOTE(review): these may shadow same-named helpers from <asm/unaligned.h>;
 * despite the name they perform a plain 32-bit access, so the pointer must be
 * suitably aligned. Both arguments are parenthesized so callers can pass
 * pointer expressions such as "p + n".
 */
#define get_unaligned_be32(a) fdt32_to_cpu(*(uint32_t *)(a))
#define put_unaligned_be32(a, b) (*(uint32_t *)(b) = cpu_to_fdt32(a))

/* Default public exponent for backward compatibility */
#define RSA_DEFAULT_PUBEXP	65537
31 
32 /**
33  * subtract_modulus() - subtract modulus from the given value
34  *
35  * @key:	Key containing modulus to subtract
36  * @num:	Number to subtract modulus from, as little endian word array
37  */
38 static void subtract_modulus(const struct rsa_public_key *key, uint32_t num[])
39 {
40 	int64_t acc = 0;
41 	uint i;
42 
43 	for (i = 0; i < key->len; i++) {
44 		acc += (uint64_t)num[i] - key->modulus[i];
45 		num[i] = (uint32_t)acc;
46 		acc >>= 32;
47 	}
48 }
49 
50 /**
51  * greater_equal_modulus() - check if a value is >= modulus
52  *
53  * @key:	Key containing modulus to check
54  * @num:	Number to check against modulus, as little endian word array
55  * @return 0 if num < modulus, 1 if num >= modulus
56  */
57 static int greater_equal_modulus(const struct rsa_public_key *key,
58 				 uint32_t num[])
59 {
60 	int i;
61 
62 	for (i = (int)key->len - 1; i >= 0; i--) {
63 		if (num[i] < key->modulus[i])
64 			return 0;
65 		if (num[i] > key->modulus[i])
66 			return 1;
67 	}
68 
69 	return 1;  /* equal */
70 }
71 
72 /**
73  * montgomery_mul_add_step() - Perform montgomery multiply-add step
74  *
75  * Operation: montgomery result[] += a * b[] / n0inv % modulus
76  *
77  * @key:	RSA key
78  * @result:	Place to put result, as little endian word array
79  * @a:		Multiplier
80  * @b:		Multiplicand, as little endian word array
81  */
82 static void montgomery_mul_add_step(const struct rsa_public_key *key,
83 		uint32_t result[], const uint32_t a, const uint32_t b[])
84 {
85 	uint64_t acc_a, acc_b;
86 	uint32_t d0;
87 	uint i;
88 
89 	acc_a = (uint64_t)a * b[0] + result[0];
90 	d0 = (uint32_t)acc_a * key->n0inv;
91 	acc_b = (uint64_t)d0 * key->modulus[0] + (uint32_t)acc_a;
92 	for (i = 1; i < key->len; i++) {
93 		acc_a = (acc_a >> 32) + (uint64_t)a * b[i] + result[i];
94 		acc_b = (acc_b >> 32) + (uint64_t)d0 * key->modulus[i] +
95 				(uint32_t)acc_a;
96 		result[i - 1] = (uint32_t)acc_b;
97 	}
98 
99 	acc_a = (acc_a >> 32) + (acc_b >> 32);
100 
101 	result[i - 1] = (uint32_t)acc_a;
102 
103 	if (acc_a >> 32)
104 		subtract_modulus(key, result);
105 }
106 
107 /**
108  * montgomery_mul() - Perform montgomery mutitply
109  *
110  * Operation: montgomery result[] = a[] * b[] / n0inv % modulus
111  *
112  * @key:	RSA key
113  * @result:	Place to put result, as little endian word array
114  * @a:		Multiplier, as little endian word array
115  * @b:		Multiplicand, as little endian word array
116  */
117 static void montgomery_mul(const struct rsa_public_key *key,
118 		uint32_t result[], uint32_t a[], const uint32_t b[])
119 {
120 	uint i;
121 
122 	for (i = 0; i < key->len; ++i)
123 		result[i] = 0;
124 	for (i = 0; i < key->len; ++i)
125 		montgomery_mul_add_step(key, result, a[i], b);
126 }
127 
128 /**
129  * num_pub_exponent_bits() - Number of bits in the public exponent
130  *
131  * @key:	RSA key
132  * @num_bits:	Storage for the number of public exponent bits
133  */
134 static int num_public_exponent_bits(const struct rsa_public_key *key,
135 		int *num_bits)
136 {
137 	uint64_t exponent;
138 	int exponent_bits;
139 	const uint max_bits = (sizeof(exponent) * 8);
140 
141 	exponent = key->exponent;
142 	exponent_bits = 0;
143 
144 	if (!exponent) {
145 		*num_bits = exponent_bits;
146 		return 0;
147 	}
148 
149 	for (exponent_bits = 1; exponent_bits < max_bits + 1; ++exponent_bits)
150 		if (!(exponent >>= 1)) {
151 			*num_bits = exponent_bits;
152 			return 0;
153 		}
154 
155 	return -EINVAL;
156 }
157 
158 /**
159  * is_public_exponent_bit_set() - Check if a bit in the public exponent is set
160  *
161  * @key:	RSA key
162  * @pos:	The bit position to check
163  */
164 static int is_public_exponent_bit_set(const struct rsa_public_key *key,
165 		int pos)
166 {
167 	return key->exponent & (1ULL << pos);
168 }
169 
170 /**
171  * pow_mod() - in-place public exponentiation
172  *
173  * @key:	RSA key
174  * @inout:	Big-endian word array containing value and result
175  */
176 static int pow_mod(const struct rsa_public_key *key, uint32_t *inout)
177 {
178 	uint32_t *result, *ptr;
179 	uint i;
180 	int j, k;
181 
182 	/* Sanity check for stack size - key->len is in 32-bit words */
183 	if (key->len > RSA_MAX_KEY_BITS / 32) {
184 		debug("RSA key words %u exceeds maximum %d\n", key->len,
185 		      RSA_MAX_KEY_BITS / 32);
186 		return -EINVAL;
187 	}
188 
189 	uint32_t val[key->len], acc[key->len], tmp[key->len];
190 	uint32_t a_scaled[key->len];
191 	result = tmp;  /* Re-use location. */
192 
193 	/* Convert from big endian byte array to little endian word array. */
194 	for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--)
195 		val[i] = get_unaligned_be32(ptr);
196 
197 	if (0 != num_public_exponent_bits(key, &k))
198 		return -EINVAL;
199 
200 	if (k < 2) {
201 		debug("Public exponent is too short (%d bits, minimum 2)\n",
202 		      k);
203 		return -EINVAL;
204 	}
205 
206 	if (!is_public_exponent_bit_set(key, 0)) {
207 		debug("LSB of RSA public exponent must be set.\n");
208 		return -EINVAL;
209 	}
210 
211 	/* the bit at e[k-1] is 1 by definition, so start with: C := M */
212 	montgomery_mul(key, acc, val, key->rr); /* acc = a * RR / R mod n */
213 	/* retain scaled version for intermediate use */
214 	memcpy(a_scaled, acc, key->len * sizeof(a_scaled[0]));
215 
216 	for (j = k - 2; j > 0; --j) {
217 		montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
218 
219 		if (is_public_exponent_bit_set(key, j)) {
220 			/* acc = tmp * val / R mod n */
221 			montgomery_mul(key, acc, tmp, a_scaled);
222 		} else {
223 			/* e[j] == 0, copy tmp back to acc for next operation */
224 			memcpy(acc, tmp, key->len * sizeof(acc[0]));
225 		}
226 	}
227 
228 	/* the bit at e[0] is always 1 */
229 	montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
230 	montgomery_mul(key, acc, tmp, val); /* acc = tmp * a / R mod M */
231 	memcpy(result, acc, key->len * sizeof(result[0]));
232 
233 	/* Make sure result < mod; result is at most 1x mod too large. */
234 	if (greater_equal_modulus(key, result))
235 		subtract_modulus(key, result);
236 
237 	/* Convert to bigendian byte array */
238 	for (i = key->len - 1, ptr = inout; (int)i >= 0; i--, ptr++)
239 		put_unaligned_be32(result[i], ptr);
240 	return 0;
241 }
242 
243 static int rsa_verify_key(const struct rsa_public_key *key, const uint8_t *sig,
244 			  const uint32_t sig_len, const uint8_t *hash,
245 			  struct checksum_algo *algo)
246 {
247 	const uint8_t *padding;
248 	int pad_len;
249 	int ret;
250 
251 	if (!key || !sig || !hash || !algo)
252 		return -EIO;
253 
254 	if (sig_len != (key->len * sizeof(uint32_t))) {
255 		debug("Signature is of incorrect length %d\n", sig_len);
256 		return -EINVAL;
257 	}
258 
259 	debug("Checksum algorithm: %s", algo->name);
260 
261 	/* Sanity check for stack size */
262 	if (sig_len > RSA_MAX_SIG_BITS / 8) {
263 		debug("Signature length %u exceeds maximum %d\n", sig_len,
264 		      RSA_MAX_SIG_BITS / 8);
265 		return -EINVAL;
266 	}
267 
268 	uint32_t buf[sig_len / sizeof(uint32_t)];
269 
270 	memcpy(buf, sig, sig_len);
271 
272 	ret = pow_mod(key, buf);
273 	if (ret)
274 		return ret;
275 
276 	padding = algo->rsa_padding;
277 	pad_len = algo->pad_len - algo->checksum_len;
278 
279 	/* Check pkcs1.5 padding bytes. */
280 	if (memcmp(buf, padding, pad_len)) {
281 		debug("In RSAVerify(): Padding check failed!\n");
282 		return -EINVAL;
283 	}
284 
285 	/* Check hash. */
286 	if (memcmp((uint8_t *)buf + pad_len, hash, sig_len - pad_len)) {
287 		debug("In RSAVerify(): Hash check failed!\n");
288 		return -EACCES;
289 	}
290 
291 	return 0;
292 }
293 
/*
 * rsa_convert_big_endian() - Reverse word order and convert from FDT order
 *
 * dst[0] receives the last word of src and dst[len - 1] the first, each
 * passed through fdt32_to_cpu(). This turns a big-endian FDT number into the
 * little endian word array used by the arithmetic in this file. Nothing is
 * written when len <= 0.
 *
 * @dst:	Output array of len words
 * @src:	Input array of len words, as stored in the FDT
 * @len:	Number of 32-bit words
 */
static void rsa_convert_big_endian(uint32_t *dst, const uint32_t *src, int len)
{
	int remaining;

	for (remaining = len; remaining > 0; remaining--)
		dst[len - remaining] = fdt32_to_cpu(src[remaining - 1]);
}
301 
302 static int rsa_verify_with_keynode(struct image_sign_info *info,
303 		const void *hash, uint8_t *sig, uint sig_len, int node)
304 {
305 	const void *blob = info->fdt_blob;
306 	struct rsa_public_key key;
307 	const void *modulus, *rr;
308 	const uint64_t *public_exponent;
309 	int length;
310 	int ret;
311 
312 	if (node < 0) {
313 		debug("%s: Skipping invalid node", __func__);
314 		return -EBADF;
315 	}
316 	if (!fdt_getprop(blob, node, "rsa,n0-inverse", NULL)) {
317 		debug("%s: Missing rsa,n0-inverse", __func__);
318 		return -EFAULT;
319 	}
320 	key.len = fdtdec_get_int(blob, node, "rsa,num-bits", 0);
321 	key.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0);
322 	public_exponent = fdt_getprop(blob, node, "rsa,exponent", &length);
323 	if (!public_exponent || length < sizeof(*public_exponent))
324 		key.exponent = RSA_DEFAULT_PUBEXP;
325 	else
326 		key.exponent = fdt64_to_cpu(*public_exponent);
327 	modulus = fdt_getprop(blob, node, "rsa,modulus", NULL);
328 	rr = fdt_getprop(blob, node, "rsa,r-squared", NULL);
329 	if (!key.len || !modulus || !rr) {
330 		debug("%s: Missing RSA key info", __func__);
331 		return -EFAULT;
332 	}
333 
334 	/* Sanity check for stack size */
335 	if (key.len > RSA_MAX_KEY_BITS || key.len < RSA_MIN_KEY_BITS) {
336 		debug("RSA key bits %u outside allowed range %d..%d\n",
337 		      key.len, RSA_MIN_KEY_BITS, RSA_MAX_KEY_BITS);
338 		return -EFAULT;
339 	}
340 	key.len /= sizeof(uint32_t) * 8;
341 	uint32_t key1[key.len], key2[key.len];
342 
343 	key.modulus = key1;
344 	key.rr = key2;
345 	rsa_convert_big_endian(key.modulus, modulus, key.len);
346 	rsa_convert_big_endian(key.rr, rr, key.len);
347 	if (!key.modulus || !key.rr) {
348 		debug("%s: Out of memory", __func__);
349 		return -ENOMEM;
350 	}
351 
352 	debug("key length %d\n", key.len);
353 	ret = rsa_verify_key(&key, sig, sig_len, hash, info->algo->checksum);
354 	if (ret) {
355 		printf("%s: RSA failed to verify: %d\n", __func__, ret);
356 		return ret;
357 	}
358 
359 	return 0;
360 }
361 
362 int rsa_verify(struct image_sign_info *info,
363 	       const struct image_region region[], int region_count,
364 	       uint8_t *sig, uint sig_len)
365 {
366 	const void *blob = info->fdt_blob;
367 	/* Reserve memory for maximum checksum-length */
368 	uint8_t hash[info->algo->checksum->pad_len];
369 	int ndepth, noffset;
370 	int sig_node, node;
371 	char name[100];
372 	int ret;
373 
374 	/*
375 	 * Verify that the checksum-length does not exceed the
376 	 * rsa-signature-length
377 	 */
378 	if (info->algo->checksum->checksum_len >
379 	    info->algo->checksum->pad_len) {
380 		debug("%s: invlaid checksum-algorithm %s for %s\n",
381 		      __func__, info->algo->checksum->name, info->algo->name);
382 		return -EINVAL;
383 	}
384 
385 	sig_node = fdt_subnode_offset(blob, 0, FIT_SIG_NODENAME);
386 	if (sig_node < 0) {
387 		debug("%s: No signature node found\n", __func__);
388 		return -ENOENT;
389 	}
390 
391 	/* Calculate checksum with checksum-algorithm */
392 	info->algo->checksum->calculate(region, region_count, hash);
393 
394 	/* See if we must use a particular key */
395 	if (info->required_keynode != -1) {
396 		ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
397 			info->required_keynode);
398 		if (!ret)
399 			return ret;
400 	}
401 
402 	/* Look for a key that matches our hint */
403 	snprintf(name, sizeof(name), "key-%s", info->keyname);
404 	node = fdt_subnode_offset(blob, sig_node, name);
405 	ret = rsa_verify_with_keynode(info, hash, sig, sig_len, node);
406 	if (!ret)
407 		return ret;
408 
409 	/* No luck, so try each of the keys in turn */
410 	for (ndepth = 0, noffset = fdt_next_node(info->fit, sig_node, &ndepth);
411 			(noffset >= 0) && (ndepth > 0);
412 			noffset = fdt_next_node(info->fit, noffset, &ndepth)) {
413 		if (ndepth == 1 && noffset != node) {
414 			ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
415 						      noffset);
416 			if (!ret)
417 				break;
418 		}
419 	}
420 
421 	return ret;
422 }
423