1 /*
2  * Copyright (c) 2013, 2014 Kenneth MacKay. All rights reserved.
3  * Copyright (c) 2019 Vitaly Chikunov <vt@altlinux.org>
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are
7  * met:
8  *  * Redistributions of source code must retain the above copyright
9  *   notice, this list of conditions and the following disclaimer.
10  *  * Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
18  * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
19  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
20  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  */
26 
27 #include <linux/module.h>
28 #include <linux/random.h>
29 #include <linux/slab.h>
30 #include <linux/swab.h>
31 #include <linux/fips.h>
32 #include <crypto/ecdh.h>
33 #include <crypto/rng.h>
34 #include <asm/unaligned.h>
35 #include <linux/ratelimit.h>
36 
37 #include "ecc.h"
38 #include "ecc_curve_defs.h"
39 
40 typedef struct {
41 	u64 m_low;
42 	u64 m_high;
43 } uint128_t;
44 
45 const struct ecc_curve *ecc_get_curve(unsigned int curve_id)
46 {
47 	switch (curve_id) {
48 	/* In FIPS mode only allow P256 and higher */
49 	case ECC_CURVE_NIST_P192:
50 		return fips_enabled ? NULL : &nist_p192;
51 	case ECC_CURVE_NIST_P256:
52 		return &nist_p256;
53 	default:
54 		return NULL;
55 	}
56 }
57 EXPORT_SYMBOL(ecc_get_curve);
58 
59 static u64 *ecc_alloc_digits_space(unsigned int ndigits)
60 {
61 	size_t len = ndigits * sizeof(u64);
62 
63 	if (!len)
64 		return NULL;
65 
66 	return kmalloc(len, GFP_KERNEL);
67 }
68 
69 static void ecc_free_digits_space(u64 *space)
70 {
71 	kfree_sensitive(space);
72 }
73 
74 static struct ecc_point *ecc_alloc_point(unsigned int ndigits)
75 {
76 	struct ecc_point *p = kmalloc(sizeof(*p), GFP_KERNEL);
77 
78 	if (!p)
79 		return NULL;
80 
81 	p->x = ecc_alloc_digits_space(ndigits);
82 	if (!p->x)
83 		goto err_alloc_x;
84 
85 	p->y = ecc_alloc_digits_space(ndigits);
86 	if (!p->y)
87 		goto err_alloc_y;
88 
89 	p->ndigits = ndigits;
90 
91 	return p;
92 
93 err_alloc_y:
94 	ecc_free_digits_space(p->x);
95 err_alloc_x:
96 	kfree(p);
97 	return NULL;
98 }
99 
100 static void ecc_free_point(struct ecc_point *p)
101 {
102 	if (!p)
103 		return;
104 
105 	kfree_sensitive(p->x);
106 	kfree_sensitive(p->y);
107 	kfree_sensitive(p);
108 }
109 
110 static void vli_clear(u64 *vli, unsigned int ndigits)
111 {
112 	int i;
113 
114 	for (i = 0; i < ndigits; i++)
115 		vli[i] = 0;
116 }
117 
118 /* Returns true if vli == 0, false otherwise. */
119 bool vli_is_zero(const u64 *vli, unsigned int ndigits)
120 {
121 	int i;
122 
123 	for (i = 0; i < ndigits; i++) {
124 		if (vli[i])
125 			return false;
126 	}
127 
128 	return true;
129 }
130 EXPORT_SYMBOL(vli_is_zero);
131 
132 /* Returns nonzero if bit 'bit' of vli is set. */
133 static u64 vli_test_bit(const u64 *vli, unsigned int bit)
134 {
135 	return (vli[bit / 64] & ((u64)1 << (bit % 64)));
136 }
137 
138 static bool vli_is_negative(const u64 *vli, unsigned int ndigits)
139 {
140 	return vli_test_bit(vli, ndigits * 64 - 1);
141 }
142 
143 /* Counts the number of 64-bit "digits" in vli. */
144 static unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits)
145 {
146 	int i;
147 
148 	/* Search from the end until we find a non-zero digit.
149 	 * We do it in reverse because we expect that most digits will
150 	 * be nonzero.
151 	 */
152 	for (i = ndigits - 1; i >= 0 && vli[i] == 0; i--);
153 
154 	return (i + 1);
155 }
156 
157 /* Counts the number of bits required for vli. */
158 static unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits)
159 {
160 	unsigned int i, num_digits;
161 	u64 digit;
162 
163 	num_digits = vli_num_digits(vli, ndigits);
164 	if (num_digits == 0)
165 		return 0;
166 
167 	digit = vli[num_digits - 1];
168 	for (i = 0; digit; i++)
169 		digit >>= 1;
170 
171 	return ((num_digits - 1) * 64 + i);
172 }
173 
174 /* Set dest from unaligned bit string src. */
175 void vli_from_be64(u64 *dest, const void *src, unsigned int ndigits)
176 {
177 	int i;
178 	const u64 *from = src;
179 
180 	for (i = 0; i < ndigits; i++)
181 		dest[i] = get_unaligned_be64(&from[ndigits - 1 - i]);
182 }
183 EXPORT_SYMBOL(vli_from_be64);
184 
185 void vli_from_le64(u64 *dest, const void *src, unsigned int ndigits)
186 {
187 	int i;
188 	const u64 *from = src;
189 
190 	for (i = 0; i < ndigits; i++)
191 		dest[i] = get_unaligned_le64(&from[i]);
192 }
193 EXPORT_SYMBOL(vli_from_le64);
194 
195 /* Sets dest = src. */
196 static void vli_set(u64 *dest, const u64 *src, unsigned int ndigits)
197 {
198 	int i;
199 
200 	for (i = 0; i < ndigits; i++)
201 		dest[i] = src[i];
202 }
203 
204 /* Returns sign of left - right. */
205 int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits)
206 {
207 	int i;
208 
209 	for (i = ndigits - 1; i >= 0; i--) {
210 		if (left[i] > right[i])
211 			return 1;
212 		else if (left[i] < right[i])
213 			return -1;
214 	}
215 
216 	return 0;
217 }
218 EXPORT_SYMBOL(vli_cmp);
219 
220 /* Computes result = in << shift, returning carry. Can modify in place
221  * (if result == in). 0 < shift < 64.
222  */
223 static u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift,
224 		      unsigned int ndigits)
225 {
226 	u64 carry = 0;
227 	int i;
228 
229 	for (i = 0; i < ndigits; i++) {
230 		u64 temp = in[i];
231 
232 		result[i] = (temp << shift) | carry;
233 		carry = temp >> (64 - shift);
234 	}
235 
236 	return carry;
237 }
238 
239 /* Computes vli = vli >> 1. */
240 static void vli_rshift1(u64 *vli, unsigned int ndigits)
241 {
242 	u64 *end = vli;
243 	u64 carry = 0;
244 
245 	vli += ndigits;
246 
247 	while (vli-- > end) {
248 		u64 temp = *vli;
249 		*vli = (temp >> 1) | carry;
250 		carry = temp << 63;
251 	}
252 }
253 
254 /* Computes result = left + right, returning carry. Can modify in place. */
255 static u64 vli_add(u64 *result, const u64 *left, const u64 *right,
256 		   unsigned int ndigits)
257 {
258 	u64 carry = 0;
259 	int i;
260 
261 	for (i = 0; i < ndigits; i++) {
262 		u64 sum;
263 
264 		sum = left[i] + right[i] + carry;
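		/*
		 * If sum == left[i], then right[i] + carry was either 0 or
		 * 2^64; in both cases the previous carry value is already
		 * correct, so it is simply kept.
		 */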
265 		if (sum != left[i])
266 			carry = (sum < left[i]);
267 
268 		result[i] = sum;
269 	}
270 
271 	return carry;
272 }
273 
274 /* Computes result = left + right, returning carry. Can modify in place. */
275 static u64 vli_uadd(u64 *result, const u64 *left, u64 right,
276 		    unsigned int ndigits)
277 {
278 	u64 carry = right;
279 	int i;
280 
281 	for (i = 0; i < ndigits; i++) {
282 		u64 sum;
283 
284 		sum = left[i] + carry;
285 		if (sum != left[i])
286 			carry = (sum < left[i]);
287 		else
288 			carry = !!carry;
289 
290 		result[i] = sum;
291 	}
292 
293 	return carry;
294 }
295 
296 /* Computes result = left - right, returning borrow. Can modify in place. */
297 u64 vli_sub(u64 *result, const u64 *left, const u64 *right,
298 		   unsigned int ndigits)
299 {
300 	u64 borrow = 0;
301 	int i;
302 
303 	for (i = 0; i < ndigits; i++) {
304 		u64 diff;
305 
306 		diff = left[i] - right[i] - borrow;
307 		if (diff != left[i])
308 			borrow = (diff > left[i]);
309 
310 		result[i] = diff;
311 	}
312 
313 	return borrow;
314 }
315 EXPORT_SYMBOL(vli_sub);
316 
317 /* Computes result = left - right, returning borrow. Can modify in place. */
318 static u64 vli_usub(u64 *result, const u64 *left, u64 right,
319 	     unsigned int ndigits)
320 {
321 	u64 borrow = right;
322 	int i;
323 
324 	for (i = 0; i < ndigits; i++) {
325 		u64 diff;
326 
327 		diff = left[i] - borrow;
328 		if (diff != left[i])
329 			borrow = (diff > left[i]);
330 
331 		result[i] = diff;
332 	}
333 
334 	return borrow;
335 }
336 
337 static uint128_t mul_64_64(u64 left, u64 right)
338 {
339 	uint128_t result;
340 #if defined(CONFIG_ARCH_SUPPORTS_INT128)
341 	unsigned __int128 m = (unsigned __int128)left * right;
342 
343 	result.m_low  = m;
344 	result.m_high = m >> 64;
345 #else
346 	u64 a0 = left & 0xffffffffull;
347 	u64 a1 = left >> 32;
348 	u64 b0 = right & 0xffffffffull;
349 	u64 b1 = right >> 32;
350 	u64 m0 = a0 * b0;
351 	u64 m1 = a0 * b1;
352 	u64 m2 = a1 * b0;
353 	u64 m3 = a1 * b1;
354 
355 	m2 += (m0 >> 32);
356 	m2 += m1;
357 
358 	/* Overflow */
359 	if (m2 < m1)
360 		m3 += 0x100000000ull;
361 
362 	result.m_low = (m0 & 0xffffffffull) | (m2 << 32);
363 	result.m_high = m3 + (m2 >> 32);
364 #endif
365 	return result;
366 }
367 
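/*
 * Illustrative sketch only (hypothetical helper, not part of this file's
 * API): a quick sanity check of the 32x32-bit schoolbook path above against
 * a product that is easy to verify by hand,
 * (2^64 - 1)^2 = 2^128 - 2^65 + 1.
 */
static void __maybe_unused mul_64_64_demo_check(void)
{
	uint128_t r = mul_64_64(0xffffffffffffffffull, 0xffffffffffffffffull);

	WARN_ON(r.m_low != 1ull || r.m_high != 0xfffffffffffffffeull);
}
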
368 static uint128_t add_128_128(uint128_t a, uint128_t b)
369 {
370 	uint128_t result;
371 
372 	result.m_low = a.m_low + b.m_low;
373 	result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low);
374 
375 	return result;
376 }
377 
378 static void vli_mult(u64 *result, const u64 *left, const u64 *right,
379 		     unsigned int ndigits)
380 {
381 	uint128_t r01 = { 0, 0 };
382 	u64 r2 = 0;
383 	unsigned int i, k;
384 
385 	/* Compute each digit of result in sequence, maintaining the
386 	 * carries.
387 	 */
388 	for (k = 0; k < ndigits * 2 - 1; k++) {
389 		unsigned int min;
390 
391 		if (k < ndigits)
392 			min = 0;
393 		else
394 			min = (k + 1) - ndigits;
395 
396 		for (i = min; i <= k && i < ndigits; i++) {
397 			uint128_t product;
398 
399 			product = mul_64_64(left[i], right[k - i]);
400 
401 			r01 = add_128_128(r01, product);
402 			r2 += (r01.m_high < product.m_high);
403 		}
404 
405 		result[k] = r01.m_low;
406 		r01.m_low = r01.m_high;
407 		r01.m_high = r2;
408 		r2 = 0;
409 	}
410 
411 	result[ndigits * 2 - 1] = r01.m_low;
412 }
413 
414 /* Computes result = left * right, for a small right value. */
415 static void vli_umult(u64 *result, const u64 *left, u32 right,
416 		      unsigned int ndigits)
417 {
418 	uint128_t r01 = { 0 };
419 	unsigned int k;
420 
421 	for (k = 0; k < ndigits; k++) {
422 		uint128_t product;
423 
424 		product = mul_64_64(left[k], right);
425 		r01 = add_128_128(r01, product);
426 		/* no carry */
427 		result[k] = r01.m_low;
428 		r01.m_low = r01.m_high;
429 		r01.m_high = 0;
430 	}
431 	result[k] = r01.m_low;
432 	for (++k; k < ndigits * 2; k++)
433 		result[k] = 0;
434 }
435 
436 static void vli_square(u64 *result, const u64 *left, unsigned int ndigits)
437 {
438 	uint128_t r01 = { 0, 0 };
439 	u64 r2 = 0;
440 	int i, k;
441 
442 	for (k = 0; k < ndigits * 2 - 1; k++) {
443 		unsigned int min;
444 
445 		if (k < ndigits)
446 			min = 0;
447 		else
448 			min = (k + 1) - ndigits;
449 
450 		for (i = min; i <= k && i <= k - i; i++) {
451 			uint128_t product;
452 
453 			product = mul_64_64(left[i], left[k - i]);
454 
455 			if (i < k - i) {
456 				r2 += product.m_high >> 63;
457 				product.m_high = (product.m_high << 1) |
458 						 (product.m_low >> 63);
459 				product.m_low <<= 1;
460 			}
461 
462 			r01 = add_128_128(r01, product);
463 			r2 += (r01.m_high < product.m_high);
464 		}
465 
466 		result[k] = r01.m_low;
467 		r01.m_low = r01.m_high;
468 		r01.m_high = r2;
469 		r2 = 0;
470 	}
471 
472 	result[ndigits * 2 - 1] = r01.m_low;
473 }
474 
475 /* Computes result = (left + right) % mod.
476  * Assumes that left < mod and right < mod, result != mod.
477  */
478 static void vli_mod_add(u64 *result, const u64 *left, const u64 *right,
479 			const u64 *mod, unsigned int ndigits)
480 {
481 	u64 carry;
482 
483 	carry = vli_add(result, left, right, ndigits);
484 
485 	/* result > mod (result = mod + remainder), so subtract mod to
486 	 * get remainder.
487 	 */
488 	if (carry || vli_cmp(result, mod, ndigits) >= 0)
489 		vli_sub(result, result, mod, ndigits);
490 }
491 
492 /* Computes result = (left - right) % mod.
493  * Assumes that left < mod and right < mod, result != mod.
494  */
495 static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right,
496 			const u64 *mod, unsigned int ndigits)
497 {
498 	u64 borrow = vli_sub(result, left, right, ndigits);
499 
500 	/* In this case, result == -diff == (max int) - diff.
501 	 * Since -x % d == d - x, we can get the correct result from
502 	 * result + mod (with overflow).
503 	 */
504 	if (borrow)
505 		vli_add(result, result, mod, ndigits);
506 }
507 
508 /*
509  * Computes result = product % mod
510  * for special form moduli: p = 2^k-c, for small c (note the minus sign)
511  *
512  * References:
513  * R. Crandall, C. Pomerance. Prime Numbers: A Computational Perspective.
514  * 9 Fast Algorithms for Large-Integer Arithmetic. 9.2.3 Moduli of special form
515  * Algorithm 9.2.13 (Fast mod operation for special-form moduli).
516  */
517 static void vli_mmod_special(u64 *result, const u64 *product,
518 			      const u64 *mod, unsigned int ndigits)
519 {
520 	u64 c = -mod[0];
521 	u64 t[ECC_MAX_DIGITS * 2];
522 	u64 r[ECC_MAX_DIGITS * 2];
523 
524 	vli_set(r, product, ndigits * 2);
525 	while (!vli_is_zero(r + ndigits, ndigits)) {
526 		vli_umult(t, r + ndigits, c, ndigits);
527 		vli_clear(r + ndigits, ndigits);
528 		vli_add(r, r, t, ndigits * 2);
529 	}
530 	vli_set(t, mod, ndigits);
531 	vli_clear(t + ndigits, ndigits);
532 	while (vli_cmp(r, t, ndigits * 2) >= 0)
533 		vli_sub(r, r, t, ndigits * 2);
534 	vli_set(result, r, ndigits);
535 }
536 
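/*
 * Illustrative sketch only (hypothetical helper): the single-word analogue
 * of the folding loop above.  For p = 2^64 - c with small non-zero c,
 * 2^64 == c (mod p), so hi * 2^64 + lo can be folded to hi * c + lo until
 * the high word is gone, then reduced below p.
 */
static u64 __maybe_unused vli_mmod_special_demo(u64 hi, u64 lo, u64 c)
{
	u64 p = 0ull - c;	/* p = 2^64 - c */

	while (hi) {
		uint128_t t = mul_64_64(hi, c);
		u64 sum = t.m_low + lo;

		hi = t.m_high + (sum < t.m_low);	/* fold, with carry */
		lo = sum;
	}
	while (lo >= p)
		lo -= p;
	return lo;
}
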
537 /*
538  * Computes result = product % mod
539  * for special form moduli: p = 2^{k-1}+c, for small c (note the plus sign)
540  * where k-1 falls one bit short of a qword boundary (such as 255).
541  *
542  * References (loosely based on):
543  * A. Menezes, P. van Oorschot, S. Vanstone. Handbook of Applied Cryptography.
544  * 14.3.4 Reduction methods for moduli of special form. Algorithm 14.47.
545  * URL: http://cacr.uwaterloo.ca/hac/about/chap14.pdf
546  *
547  * H. Cohen, G. Frey, R. Avanzi, C. Doche, T. Lange, K. Nguyen, F. Vercauteren.
548  * Handbook of Elliptic and Hyperelliptic Curve Cryptography.
549  * Algorithm 10.25 Fast reduction for special form moduli
550  */
551 static void vli_mmod_special2(u64 *result, const u64 *product,
552 			       const u64 *mod, unsigned int ndigits)
553 {
554 	u64 c2 = mod[0] * 2;
555 	u64 q[ECC_MAX_DIGITS];
556 	u64 r[ECC_MAX_DIGITS * 2];
557 	u64 m[ECC_MAX_DIGITS * 2]; /* expanded mod */
558 	int carry; /* last bit that doesn't fit into q */
559 	int i;
560 
561 	vli_set(m, mod, ndigits);
562 	vli_clear(m + ndigits, ndigits);
563 
564 	vli_set(r, product, ndigits);
565 	/* q and carry are top bits */
566 	vli_set(q, product + ndigits, ndigits);
567 	vli_clear(r + ndigits, ndigits);
568 	carry = vli_is_negative(r, ndigits);
569 	if (carry)
570 		r[ndigits - 1] &= (1ull << 63) - 1;
571 	for (i = 1; carry || !vli_is_zero(q, ndigits); i++) {
572 		u64 qc[ECC_MAX_DIGITS * 2];
573 
574 		vli_umult(qc, q, c2, ndigits);
575 		if (carry)
576 			vli_uadd(qc, qc, mod[0], ndigits * 2);
577 		vli_set(q, qc + ndigits, ndigits);
578 		vli_clear(qc + ndigits, ndigits);
579 		carry = vli_is_negative(qc, ndigits);
580 		if (carry)
581 			qc[ndigits - 1] &= (1ull << 63) - 1;
582 		if (i & 1)
583 			vli_sub(r, r, qc, ndigits * 2);
584 		else
585 			vli_add(r, r, qc, ndigits * 2);
586 	}
587 	while (vli_is_negative(r, ndigits * 2))
588 		vli_add(r, r, m, ndigits * 2);
589 	while (vli_cmp(r, m, ndigits * 2) >= 0)
590 		vli_sub(r, r, m, ndigits * 2);
591 
592 	vli_set(result, r, ndigits);
593 }
594 
595 /*
596  * Computes result = product % mod, where product is 2N words long.
597  * Reference: Ken MacKay's micro-ecc.
598  * Currently only designed to work for curve_p or curve_n.
599  */
600 static void vli_mmod_slow(u64 *result, u64 *product, const u64 *mod,
601 			  unsigned int ndigits)
602 {
603 	u64 mod_m[2 * ECC_MAX_DIGITS];
604 	u64 tmp[2 * ECC_MAX_DIGITS];
605 	u64 *v[2] = { tmp, product };
606 	u64 carry = 0;
607 	unsigned int i;
608 	/* Shift mod so its highest set bit is at the maximum position. */
609 	int shift = (ndigits * 2 * 64) - vli_num_bits(mod, ndigits);
610 	int word_shift = shift / 64;
611 	int bit_shift = shift % 64;
612 
613 	vli_clear(mod_m, word_shift);
614 	if (bit_shift > 0) {
615 		for (i = 0; i < ndigits; ++i) {
616 			mod_m[word_shift + i] = (mod[i] << bit_shift) | carry;
617 			carry = mod[i] >> (64 - bit_shift);
618 		}
619 	} else
620 		vli_set(mod_m + word_shift, mod, ndigits);
621 
622 	for (i = 1; shift >= 0; --shift) {
623 		u64 borrow = 0;
624 		unsigned int j;
625 
626 		for (j = 0; j < ndigits * 2; ++j) {
627 			u64 diff = v[i][j] - mod_m[j] - borrow;
628 
629 			if (diff != v[i][j])
630 				borrow = (diff > v[i][j]);
631 			v[1 - i][j] = diff;
632 		}
633 		i = !(i ^ borrow); /* Swap the index if there was no borrow */
634 		vli_rshift1(mod_m, ndigits);
635 		mod_m[ndigits - 1] |= mod_m[ndigits] << (64 - 1);
636 		vli_rshift1(mod_m + ndigits, ndigits);
637 	}
638 	vli_set(result, v[i], ndigits);
639 }
640 
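/*
 * Illustrative sketch only (hypothetical helper): the same shift-and-
 * subtract reduction as above, on a single 64-bit word (mod must be
 * non-zero).  The real code avoids the data-dependent branch by always
 * subtracting and then selecting the result through an index swap.
 */
static u64 __maybe_unused vli_mmod_slow_demo(u64 x, u64 mod)
{
	u64 m = mod;
	int shift = 0;

	/* Align mod's highest set bit with bit 63. */
	while (!(m & (1ull << 63))) {
		m <<= 1;
		shift++;
	}
	/* Walk back down, subtracting whenever the shifted mod still fits. */
	for (; shift >= 0; shift--, m >>= 1) {
		if (x >= m)
			x -= m;
	}
	return x;
}
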
641 /* Computes result = product % mod using Barrett's reduction with a precomputed
642  * value mu appended to the mod after its ndigits. mu = floor(2^{2w} / mod), has
643  * length ndigits + 1, and mu * (2^w - 1) must not overflow the ndigits
644  * boundary.
645  *
646  * Reference:
647  * R. Brent, P. Zimmermann. Modern Computer Arithmetic. 2010.
648  * 2.4.1 Barrett's algorithm. Algorithm 2.5.
649  */
650 static void vli_mmod_barrett(u64 *result, u64 *product, const u64 *mod,
651 			     unsigned int ndigits)
652 {
653 	u64 q[ECC_MAX_DIGITS * 2];
654 	u64 r[ECC_MAX_DIGITS * 2];
655 	const u64 *mu = mod + ndigits;
656 
657 	vli_mult(q, product + ndigits, mu, ndigits);
658 	if (mu[ndigits])
659 		vli_add(q + ndigits, q + ndigits, product + ndigits, ndigits);
660 	vli_mult(r, mod, q + ndigits, ndigits);
661 	vli_sub(r, product, r, ndigits * 2);
662 	while (!vli_is_zero(r + ndigits, ndigits) ||
663 	       vli_cmp(r, mod, ndigits) != -1) {
664 		u64 carry;
665 
666 		carry = vli_sub(r, r, mod, ndigits);
667 		vli_usub(r + ndigits, r + ndigits, carry, ndigits);
668 	}
669 	vli_set(result, r, ndigits);
670 }
671 
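/*
 * Illustrative sketch only (hypothetical helper): Barrett's quotient
 * estimate on a toy 16-bit word size, with mu = floor(2^32 / m) precomputed
 * by the caller, just as the full-width code above expects mu appended to
 * mod.  For a modulus with its top bit set the estimate undershoots by at
 * most a couple of multiples of m, hence the small correction loop.
 */
static u32 __maybe_unused barrett_demo(u32 product, u16 m, u64 mu)
{
	u32 q = ((u64)(product >> 16) * mu) >> 16;	/* q <= product / m */
	u32 r = product - q * m;

	while (r >= m)
		r -= m;
	return r;
}
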
672 /* Computes result = product % curve_prime.
673  * See algorithms 5 and 6 from
674  * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf
675  */
676 static void vli_mmod_fast_192(u64 *result, const u64 *product,
677 			      const u64 *curve_prime, u64 *tmp)
678 {
679 	const unsigned int ndigits = 3;
680 	int carry;
681 
682 	vli_set(result, product, ndigits);
683 
684 	vli_set(tmp, &product[3], ndigits);
685 	carry = vli_add(result, result, tmp, ndigits);
686 
687 	tmp[0] = 0;
688 	tmp[1] = product[3];
689 	tmp[2] = product[4];
690 	carry += vli_add(result, result, tmp, ndigits);
691 
692 	tmp[0] = tmp[1] = product[5];
693 	tmp[2] = 0;
694 	carry += vli_add(result, result, tmp, ndigits);
695 
696 	while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
697 		carry -= vli_sub(result, result, curve_prime, ndigits);
698 }
699 
700 /* Computes result = product % curve_prime
701  * from http://www.nsa.gov/ia/_files/nist-routines.pdf
702  */
703 static void vli_mmod_fast_256(u64 *result, const u64 *product,
704 			      const u64 *curve_prime, u64 *tmp)
705 {
706 	int carry;
707 	const unsigned int ndigits = 4;
708 
709 	/* t */
710 	vli_set(result, product, ndigits);
711 
712 	/* s1 */
713 	tmp[0] = 0;
714 	tmp[1] = product[5] & 0xffffffff00000000ull;
715 	tmp[2] = product[6];
716 	tmp[3] = product[7];
717 	carry = vli_lshift(tmp, tmp, 1, ndigits);
718 	carry += vli_add(result, result, tmp, ndigits);
719 
720 	/* s2 */
721 	tmp[1] = product[6] << 32;
722 	tmp[2] = (product[6] >> 32) | (product[7] << 32);
723 	tmp[3] = product[7] >> 32;
724 	carry += vli_lshift(tmp, tmp, 1, ndigits);
725 	carry += vli_add(result, result, tmp, ndigits);
726 
727 	/* s3 */
728 	tmp[0] = product[4];
729 	tmp[1] = product[5] & 0xffffffff;
730 	tmp[2] = 0;
731 	tmp[3] = product[7];
732 	carry += vli_add(result, result, tmp, ndigits);
733 
734 	/* s4 */
735 	tmp[0] = (product[4] >> 32) | (product[5] << 32);
736 	tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull);
737 	tmp[2] = product[7];
738 	tmp[3] = (product[6] >> 32) | (product[4] << 32);
739 	carry += vli_add(result, result, tmp, ndigits);
740 
741 	/* d1 */
742 	tmp[0] = (product[5] >> 32) | (product[6] << 32);
743 	tmp[1] = (product[6] >> 32);
744 	tmp[2] = 0;
745 	tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32);
746 	carry -= vli_sub(result, result, tmp, ndigits);
747 
748 	/* d2 */
749 	tmp[0] = product[6];
750 	tmp[1] = product[7];
751 	tmp[2] = 0;
752 	tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull);
753 	carry -= vli_sub(result, result, tmp, ndigits);
754 
755 	/* d3 */
756 	tmp[0] = (product[6] >> 32) | (product[7] << 32);
757 	tmp[1] = (product[7] >> 32) | (product[4] << 32);
758 	tmp[2] = (product[4] >> 32) | (product[5] << 32);
759 	tmp[3] = (product[6] << 32);
760 	carry -= vli_sub(result, result, tmp, ndigits);
761 
762 	/* d4 */
763 	tmp[0] = product[7];
764 	tmp[1] = product[4] & 0xffffffff00000000ull;
765 	tmp[2] = product[5];
766 	tmp[3] = product[6] & 0xffffffff00000000ull;
767 	carry -= vli_sub(result, result, tmp, ndigits);
768 
769 	if (carry < 0) {
770 		do {
771 			carry += vli_add(result, result, curve_prime, ndigits);
772 		} while (carry < 0);
773 	} else {
774 		while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
775 			carry -= vli_sub(result, result, curve_prime, ndigits);
776 	}
777 }
778 
779 /* Computes result = product % curve_prime for different curve_primes.
780  *
781  * Note that curve_primes are distinguished only by a heuristic check and
782  * not by a complete conformance check.
783  */
784 static bool vli_mmod_fast(u64 *result, u64 *product,
785 			  const u64 *curve_prime, unsigned int ndigits)
786 {
787 	u64 tmp[2 * ECC_MAX_DIGITS];
788 
789 	/* Currently, both NIST primes have -1 in lowest qword. */
790 	if (curve_prime[0] != -1ull) {
791 		/* Try to handle Pseudo-Mersenne primes. */
792 		if (curve_prime[ndigits - 1] == -1ull) {
793 			vli_mmod_special(result, product, curve_prime,
794 					 ndigits);
795 			return true;
796 		} else if (curve_prime[ndigits - 1] == 1ull << 63 &&
797 			   curve_prime[ndigits - 2] == 0) {
798 			vli_mmod_special2(result, product, curve_prime,
799 					  ndigits);
800 			return true;
801 		}
802 		vli_mmod_barrett(result, product, curve_prime, ndigits);
803 		return true;
804 	}
805 
806 	switch (ndigits) {
807 	case 3:
808 		vli_mmod_fast_192(result, product, curve_prime, tmp);
809 		break;
810 	case 4:
811 		vli_mmod_fast_256(result, product, curve_prime, tmp);
812 		break;
813 	default:
814 		pr_err_ratelimited("ecc: unsupported digits size!\n");
815 		return false;
816 	}
817 
818 	return true;
819 }
820 
821 /* Computes result = (left * right) % mod.
822  * Assumes that mod is a sufficiently large curve order.
823  */
824 void vli_mod_mult_slow(u64 *result, const u64 *left, const u64 *right,
825 		       const u64 *mod, unsigned int ndigits)
826 {
827 	u64 product[ECC_MAX_DIGITS * 2];
828 
829 	vli_mult(product, left, right, ndigits);
830 	vli_mmod_slow(result, product, mod, ndigits);
831 }
832 EXPORT_SYMBOL(vli_mod_mult_slow);
833 
834 /* Computes result = (left * right) % curve_prime. */
835 static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right,
836 			      const u64 *curve_prime, unsigned int ndigits)
837 {
838 	u64 product[2 * ECC_MAX_DIGITS];
839 
840 	vli_mult(product, left, right, ndigits);
841 	vli_mmod_fast(result, product, curve_prime, ndigits);
842 }
843 
844 /* Computes result = left^2 % curve_prime. */
845 static void vli_mod_square_fast(u64 *result, const u64 *left,
846 				const u64 *curve_prime, unsigned int ndigits)
847 {
848 	u64 product[2 * ECC_MAX_DIGITS];
849 
850 	vli_square(product, left, ndigits);
851 	vli_mmod_fast(result, product, curve_prime, ndigits);
852 }
853 
854 #define EVEN(vli) (!(vli[0] & 1))
855 /* Computes result = (1 / input) % mod. All VLIs are the same size.
856  * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide"
857  * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf
858  */
859 void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod,
860 			unsigned int ndigits)
861 {
862 	u64 a[ECC_MAX_DIGITS], b[ECC_MAX_DIGITS];
863 	u64 u[ECC_MAX_DIGITS], v[ECC_MAX_DIGITS];
864 	u64 carry;
865 	int cmp_result;
866 
867 	if (vli_is_zero(input, ndigits)) {
868 		vli_clear(result, ndigits);
869 		return;
870 	}
871 
872 	vli_set(a, input, ndigits);
873 	vli_set(b, mod, ndigits);
874 	vli_clear(u, ndigits);
875 	u[0] = 1;
876 	vli_clear(v, ndigits);
877 
878 	while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) {
879 		carry = 0;
880 
881 		if (EVEN(a)) {
882 			vli_rshift1(a, ndigits);
883 
884 			if (!EVEN(u))
885 				carry = vli_add(u, u, mod, ndigits);
886 
887 			vli_rshift1(u, ndigits);
888 			if (carry)
889 				u[ndigits - 1] |= 0x8000000000000000ull;
890 		} else if (EVEN(b)) {
891 			vli_rshift1(b, ndigits);
892 
893 			if (!EVEN(v))
894 				carry = vli_add(v, v, mod, ndigits);
895 
896 			vli_rshift1(v, ndigits);
897 			if (carry)
898 				v[ndigits - 1] |= 0x8000000000000000ull;
899 		} else if (cmp_result > 0) {
900 			vli_sub(a, a, b, ndigits);
901 			vli_rshift1(a, ndigits);
902 
903 			if (vli_cmp(u, v, ndigits) < 0)
904 				vli_add(u, u, mod, ndigits);
905 
906 			vli_sub(u, u, v, ndigits);
907 			if (!EVEN(u))
908 				carry = vli_add(u, u, mod, ndigits);
909 
910 			vli_rshift1(u, ndigits);
911 			if (carry)
912 				u[ndigits - 1] |= 0x8000000000000000ull;
913 		} else {
914 			vli_sub(b, b, a, ndigits);
915 			vli_rshift1(b, ndigits);
916 
917 			if (vli_cmp(v, u, ndigits) < 0)
918 				vli_add(v, v, mod, ndigits);
919 
920 			vli_sub(v, v, u, ndigits);
921 			if (!EVEN(v))
922 				carry = vli_add(v, v, mod, ndigits);
923 
924 			vli_rshift1(v, ndigits);
925 			if (carry)
926 				v[ndigits - 1] |= 0x8000000000000000ull;
927 		}
928 	}
929 
930 	vli_set(result, u, ndigits);
931 }
932 EXPORT_SYMBOL(vli_mod_inv);
933 
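/*
 * Illustrative sketch only (hypothetical helper): a caller can sanity-check
 * an inverse from vli_mod_inv() by multiplying it back, since
 * (input * input^-1) mod mod must equal 1 for the odd moduli used here
 * (curve primes and orders) when input is non-zero.
 */
static bool __maybe_unused vli_mod_inv_demo_check(const u64 *input,
						  const u64 *mod,
						  unsigned int ndigits)
{
	u64 inv[ECC_MAX_DIGITS], chk[ECC_MAX_DIGITS];
	u64 one[ECC_MAX_DIGITS] = { 1 };

	vli_mod_inv(inv, input, mod, ndigits);
	vli_mod_mult_slow(chk, input, inv, mod, ndigits);

	return vli_cmp(chk, one, ndigits) == 0;
}
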
934 /* ------ Point operations ------ */
935 
936 /* Returns true if point is the point at infinity, false otherwise. */
937 static bool ecc_point_is_zero(const struct ecc_point *point)
938 {
939 	return (vli_is_zero(point->x, point->ndigits) &&
940 		vli_is_zero(point->y, point->ndigits));
941 }
942 
943 /* Point multiplication algorithm using Montgomery's ladder with co-Z
944  * coordinates. From https://eprint.iacr.org/2011/338.pdf
945  */
946 
947 /* Double in place */
948 static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1,
949 				      u64 *curve_prime, unsigned int ndigits)
950 {
951 	/* t1 = x, t2 = y, t3 = z */
952 	u64 t4[ECC_MAX_DIGITS];
953 	u64 t5[ECC_MAX_DIGITS];
954 
955 	if (vli_is_zero(z1, ndigits))
956 		return;
957 
958 	/* t4 = y1^2 */
959 	vli_mod_square_fast(t4, y1, curve_prime, ndigits);
960 	/* t5 = x1*y1^2 = A */
961 	vli_mod_mult_fast(t5, x1, t4, curve_prime, ndigits);
962 	/* t4 = y1^4 */
963 	vli_mod_square_fast(t4, t4, curve_prime, ndigits);
964 	/* t2 = y1*z1 = z3 */
965 	vli_mod_mult_fast(y1, y1, z1, curve_prime, ndigits);
966 	/* t3 = z1^2 */
967 	vli_mod_square_fast(z1, z1, curve_prime, ndigits);
968 
969 	/* t1 = x1 + z1^2 */
970 	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
971 	/* t3 = 2*z1^2 */
972 	vli_mod_add(z1, z1, z1, curve_prime, ndigits);
973 	/* t3 = x1 - z1^2 */
974 	vli_mod_sub(z1, x1, z1, curve_prime, ndigits);
975 	/* t1 = x1^2 - z1^4 */
976 	vli_mod_mult_fast(x1, x1, z1, curve_prime, ndigits);
977 
978 	/* t3 = 2*(x1^2 - z1^4) */
979 	vli_mod_add(z1, x1, x1, curve_prime, ndigits);
980 	/* t1 = 3*(x1^2 - z1^4) */
981 	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
982 	if (vli_test_bit(x1, 0)) {
983 		u64 carry = vli_add(x1, x1, curve_prime, ndigits);
984 
985 		vli_rshift1(x1, ndigits);
986 		x1[ndigits - 1] |= carry << 63;
987 	} else {
988 		vli_rshift1(x1, ndigits);
989 	}
990 	/* t1 = 3/2*(x1^2 - z1^4) = B */
991 
992 	/* t3 = B^2 */
993 	vli_mod_square_fast(z1, x1, curve_prime, ndigits);
994 	/* t3 = B^2 - A */
995 	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
996 	/* t3 = B^2 - 2A = x3 */
997 	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
998 	/* t5 = A - x3 */
999 	vli_mod_sub(t5, t5, z1, curve_prime, ndigits);
1000 	/* t1 = B * (A - x3) */
1001 	vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
1002 	/* t4 = B * (A - x3) - y1^4 = y3 */
1003 	vli_mod_sub(t4, x1, t4, curve_prime, ndigits);
1004 
1005 	vli_set(x1, z1, ndigits);
1006 	vli_set(z1, y1, ndigits);
1007 	vli_set(y1, t4, ndigits);
1008 }
1009 
1010 /* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */
1011 static void apply_z(u64 *x1, u64 *y1, u64 *z, u64 *curve_prime,
1012 		    unsigned int ndigits)
1013 {
1014 	u64 t1[ECC_MAX_DIGITS];
1015 
1016 	vli_mod_square_fast(t1, z, curve_prime, ndigits);    /* z^2 */
1017 	vli_mod_mult_fast(x1, x1, t1, curve_prime, ndigits); /* x1 * z^2 */
1018 	vli_mod_mult_fast(t1, t1, z, curve_prime, ndigits);  /* z^3 */
1019 	vli_mod_mult_fast(y1, y1, t1, curve_prime, ndigits); /* y1 * z^3 */
1020 }
1021 
1022 /* P = (x1, y1) => 2P, (x2, y2) => P' */
1023 static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
1024 				u64 *p_initial_z, u64 *curve_prime,
1025 				unsigned int ndigits)
1026 {
1027 	u64 z[ECC_MAX_DIGITS];
1028 
1029 	vli_set(x2, x1, ndigits);
1030 	vli_set(y2, y1, ndigits);
1031 
1032 	vli_clear(z, ndigits);
1033 	z[0] = 1;
1034 
1035 	if (p_initial_z)
1036 		vli_set(z, p_initial_z, ndigits);
1037 
1038 	apply_z(x1, y1, z, curve_prime, ndigits);
1039 
1040 	ecc_point_double_jacobian(x1, y1, z, curve_prime, ndigits);
1041 
1042 	apply_z(x2, y2, z, curve_prime, ndigits);
1043 }
1044 
1045 /* Input P = (x1, y1, Z), Q = (x2, y2, Z)
1046  * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3)
1047  * or P => P', Q => P + Q
1048  */
1049 static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime,
1050 		     unsigned int ndigits)
1051 {
1052 	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
1053 	u64 t5[ECC_MAX_DIGITS];
1054 
1055 	/* t5 = x2 - x1 */
1056 	vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
1057 	/* t5 = (x2 - x1)^2 = A */
1058 	vli_mod_square_fast(t5, t5, curve_prime, ndigits);
1059 	/* t1 = x1*A = B */
1060 	vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
1061 	/* t3 = x2*A = C */
1062 	vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits);
1063 	/* t4 = y2 - y1 */
1064 	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
1065 	/* t5 = (y2 - y1)^2 = D */
1066 	vli_mod_square_fast(t5, y2, curve_prime, ndigits);
1067 
1068 	/* t5 = D - B */
1069 	vli_mod_sub(t5, t5, x1, curve_prime, ndigits);
1070 	/* t5 = D - B - C = x3 */
1071 	vli_mod_sub(t5, t5, x2, curve_prime, ndigits);
1072 	/* t3 = C - B */
1073 	vli_mod_sub(x2, x2, x1, curve_prime, ndigits);
1074 	/* t2 = y1*(C - B) */
1075 	vli_mod_mult_fast(y1, y1, x2, curve_prime, ndigits);
1076 	/* t3 = B - x3 */
1077 	vli_mod_sub(x2, x1, t5, curve_prime, ndigits);
1078 	/* t4 = (y2 - y1)*(B - x3) */
1079 	vli_mod_mult_fast(y2, y2, x2, curve_prime, ndigits);
1080 	/* t4 = y3 */
1081 	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
1082 
1083 	vli_set(x2, t5, ndigits);
1084 }
1085 
1086 /* Input P = (x1, y1, Z), Q = (x2, y2, Z)
1087  * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3)
1088  * or P => P - Q, Q => P + Q
1089  */
1090 static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime,
1091 		       unsigned int ndigits)
1092 {
1093 	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
1094 	u64 t5[ECC_MAX_DIGITS];
1095 	u64 t6[ECC_MAX_DIGITS];
1096 	u64 t7[ECC_MAX_DIGITS];
1097 
1098 	/* t5 = x2 - x1 */
1099 	vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
1100 	/* t5 = (x2 - x1)^2 = A */
1101 	vli_mod_square_fast(t5, t5, curve_prime, ndigits);
1102 	/* t1 = x1*A = B */
1103 	vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
1104 	/* t3 = x2*A = C */
1105 	vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits);
1106 	/* t4 = y2 + y1 */
1107 	vli_mod_add(t5, y2, y1, curve_prime, ndigits);
1108 	/* t4 = y2 - y1 */
1109 	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
1110 
1111 	/* t6 = C - B */
1112 	vli_mod_sub(t6, x2, x1, curve_prime, ndigits);
1113 	/* t2 = y1 * (C - B) */
1114 	vli_mod_mult_fast(y1, y1, t6, curve_prime, ndigits);
1115 	/* t6 = B + C */
1116 	vli_mod_add(t6, x1, x2, curve_prime, ndigits);
1117 	/* t3 = (y2 - y1)^2 */
1118 	vli_mod_square_fast(x2, y2, curve_prime, ndigits);
1119 	/* t3 = x3 */
1120 	vli_mod_sub(x2, x2, t6, curve_prime, ndigits);
1121 
1122 	/* t7 = B - x3 */
1123 	vli_mod_sub(t7, x1, x2, curve_prime, ndigits);
1124 	/* t4 = (y2 - y1)*(B - x3) */
1125 	vli_mod_mult_fast(y2, y2, t7, curve_prime, ndigits);
1126 	/* t4 = y3 */
1127 	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
1128 
1129 	/* t7 = (y2 + y1)^2 = F */
1130 	vli_mod_square_fast(t7, t5, curve_prime, ndigits);
1131 	/* t7 = x3' */
1132 	vli_mod_sub(t7, t7, t6, curve_prime, ndigits);
1133 	/* t6 = x3' - B */
1134 	vli_mod_sub(t6, t7, x1, curve_prime, ndigits);
1135 	/* t6 = (y2 + y1)*(x3' - B) */
1136 	vli_mod_mult_fast(t6, t6, t5, curve_prime, ndigits);
1137 	/* t2 = y3' */
1138 	vli_mod_sub(y1, t6, y1, curve_prime, ndigits);
1139 
1140 	vli_set(x1, t7, ndigits);
1141 }
1142 
1143 static void ecc_point_mult(struct ecc_point *result,
1144 			   const struct ecc_point *point, const u64 *scalar,
1145 			   u64 *initial_z, const struct ecc_curve *curve,
1146 			   unsigned int ndigits)
1147 {
1148 	/* R0 and R1 */
1149 	u64 rx[2][ECC_MAX_DIGITS];
1150 	u64 ry[2][ECC_MAX_DIGITS];
1151 	u64 z[ECC_MAX_DIGITS];
1152 	u64 sk[2][ECC_MAX_DIGITS];
1153 	u64 *curve_prime = curve->p;
1154 	int i, nb;
1155 	int num_bits;
1156 	int carry;
1157 
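	/*
	 * Regularize the scalar: use scalar + n or scalar + 2n, whichever
	 * overflows into bit 64 * ndigits, so the ladder below always runs
	 * over the same fixed number of bits with an implicit leading one,
	 * independent of the scalar's actual length.  Adding a multiple of
	 * the group order n does not change the resulting point.
	 */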
1158 	carry = vli_add(sk[0], scalar, curve->n, ndigits);
1159 	vli_add(sk[1], sk[0], curve->n, ndigits);
1160 	scalar = sk[!carry];
1161 	num_bits = sizeof(u64) * ndigits * 8 + 1;
1162 
1163 	vli_set(rx[1], point->x, ndigits);
1164 	vli_set(ry[1], point->y, ndigits);
1165 
1166 	xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve_prime,
1167 			    ndigits);
1168 
1169 	for (i = num_bits - 2; i > 0; i--) {
1170 		nb = !vli_test_bit(scalar, i);
1171 		xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime,
1172 			   ndigits);
1173 		xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime,
1174 			 ndigits);
1175 	}
1176 
1177 	nb = !vli_test_bit(scalar, 0);
1178 	xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime,
1179 		   ndigits);
1180 
1181 	/* Find final 1/Z value. */
1182 	/* X1 - X0 */
1183 	vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits);
1184 	/* Yb * (X1 - X0) */
1185 	vli_mod_mult_fast(z, z, ry[1 - nb], curve_prime, ndigits);
1186 	/* xP * Yb * (X1 - X0) */
1187 	vli_mod_mult_fast(z, z, point->x, curve_prime, ndigits);
1188 
1189 	/* 1 / (xP * Yb * (X1 - X0)) */
1190 	vli_mod_inv(z, z, curve_prime, point->ndigits);
1191 
1192 	/* yP / (xP * Yb * (X1 - X0)) */
1193 	vli_mod_mult_fast(z, z, point->y, curve_prime, ndigits);
1194 	/* Xb * yP / (xP * Yb * (X1 - X0)) */
1195 	vli_mod_mult_fast(z, z, rx[1 - nb], curve_prime, ndigits);
1196 	/* End 1/Z calculation */
1197 
1198 	xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime, ndigits);
1199 
1200 	apply_z(rx[0], ry[0], z, curve_prime, ndigits);
1201 
1202 	vli_set(result->x, rx[0], ndigits);
1203 	vli_set(result->y, ry[0], ndigits);
1204 }
1205 
1206 /* Computes R = P + Q mod p */
1207 static void ecc_point_add(const struct ecc_point *result,
1208 		   const struct ecc_point *p, const struct ecc_point *q,
1209 		   const struct ecc_curve *curve)
1210 {
1211 	u64 z[ECC_MAX_DIGITS];
1212 	u64 px[ECC_MAX_DIGITS];
1213 	u64 py[ECC_MAX_DIGITS];
1214 	unsigned int ndigits = curve->g.ndigits;
1215 
1216 	vli_set(result->x, q->x, ndigits);
1217 	vli_set(result->y, q->y, ndigits);
1218 	vli_mod_sub(z, result->x, p->x, curve->p, ndigits);
1219 	vli_set(px, p->x, ndigits);
1220 	vli_set(py, p->y, ndigits);
1221 	xycz_add(px, py, result->x, result->y, curve->p, ndigits);
1222 	vli_mod_inv(z, z, curve->p, ndigits);
1223 	apply_z(result->x, result->y, z, curve->p, ndigits);
1224 }
1225 
1226 /* Computes R = u1P + u2Q mod p using Shamir's trick.
1227  * Based on: Kenneth MacKay's micro-ecc (2014).
1228  */
1229 void ecc_point_mult_shamir(const struct ecc_point *result,
1230 			   const u64 *u1, const struct ecc_point *p,
1231 			   const u64 *u2, const struct ecc_point *q,
1232 			   const struct ecc_curve *curve)
1233 {
1234 	u64 z[ECC_MAX_DIGITS];
1235 	u64 sump[2][ECC_MAX_DIGITS];
1236 	u64 *rx = result->x;
1237 	u64 *ry = result->y;
1238 	unsigned int ndigits = curve->g.ndigits;
1239 	unsigned int num_bits;
1240 	struct ecc_point sum = ECC_POINT_INIT(sump[0], sump[1], ndigits);
1241 	const struct ecc_point *points[4];
1242 	const struct ecc_point *point;
1243 	unsigned int idx;
1244 	int i;
1245 
1246 	ecc_point_add(&sum, p, q, curve);
1247 	points[0] = NULL;
1248 	points[1] = p;
1249 	points[2] = q;
1250 	points[3] = &sum;
1251 
1252 	num_bits = max(vli_num_bits(u1, ndigits),
1253 		       vli_num_bits(u2, ndigits));
1254 	i = num_bits - 1;
1255 	idx = (!!vli_test_bit(u1, i)) | ((!!vli_test_bit(u2, i)) << 1);
1256 	point = points[idx];
1257 
1258 	vli_set(rx, point->x, ndigits);
1259 	vli_set(ry, point->y, ndigits);
1260 	vli_clear(z + 1, ndigits - 1);
1261 	z[0] = 1;
1262 
1263 	for (--i; i >= 0; i--) {
1264 		ecc_point_double_jacobian(rx, ry, z, curve->p, ndigits);
1265 		idx = (!!vli_test_bit(u1, i)) | ((!!vli_test_bit(u2, i)) << 1);
1266 		point = points[idx];
1267 		if (point) {
1268 			u64 tx[ECC_MAX_DIGITS];
1269 			u64 ty[ECC_MAX_DIGITS];
1270 			u64 tz[ECC_MAX_DIGITS];
1271 
1272 			vli_set(tx, point->x, ndigits);
1273 			vli_set(ty, point->y, ndigits);
1274 			apply_z(tx, ty, z, curve->p, ndigits);
1275 			vli_mod_sub(tz, rx, tx, curve->p, ndigits);
1276 			xycz_add(tx, ty, rx, ry, curve->p, ndigits);
1277 			vli_mod_mult_fast(z, z, tz, curve->p, ndigits);
1278 		}
1279 	}
1280 	vli_mod_inv(z, z, curve->p, ndigits);
1281 	apply_z(rx, ry, z, curve->p, ndigits);
1282 }
1283 EXPORT_SYMBOL(ecc_point_mult_shamir);
1284 
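/*
 * Illustrative sketch only (hypothetical helper): the same joint
 * double-and-add idea as ecc_point_mult_shamir(), shown on plain integers
 * (arithmetic here is simply mod 2^64).  Each step doubles the accumulator
 * and adds one of {0, p, q, p + q} selected by the current bit pair, so
 * u1 * p + u2 * q is computed in a single pass over the bits.
 */
static u64 __maybe_unused shamir_trick_demo(u64 u1, u64 p, u64 u2, u64 q)
{
	u64 table[4] = { 0, p, q, p + q };
	u64 acc = 0;
	int i;

	for (i = 63; i >= 0; i--) {
		unsigned int idx = (!!(u1 & (1ull << i))) |
				   ((!!(u2 & (1ull << i))) << 1);

		acc = 2 * acc + table[idx];
	}
	return acc;
}
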
1285 static int __ecc_is_key_valid(const struct ecc_curve *curve,
1286 			      const u64 *private_key, unsigned int ndigits)
1287 {
1288 	u64 one[ECC_MAX_DIGITS] = { 1, };
1289 	u64 res[ECC_MAX_DIGITS];
1290 
1291 	if (!private_key)
1292 		return -EINVAL;
1293 
1294 	if (curve->g.ndigits != ndigits)
1295 		return -EINVAL;
1296 
1297 	/* Make sure the private key is in the range [2, n-3]. */
1298 	if (vli_cmp(one, private_key, ndigits) != -1)
1299 		return -EINVAL;
1300 	vli_sub(res, curve->n, one, ndigits);
1301 	vli_sub(res, res, one, ndigits);
1302 	if (vli_cmp(res, private_key, ndigits) != 1)
1303 		return -EINVAL;
1304 
1305 	return 0;
1306 }
1307 
1308 int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits,
1309 		     const u64 *private_key, unsigned int private_key_len)
1310 {
1311 	int nbytes;
1312 	const struct ecc_curve *curve = ecc_get_curve(curve_id);
1313 
1314 	nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
1315 
1316 	if (private_key_len != nbytes)
1317 		return -EINVAL;
1318 
1319 	return __ecc_is_key_valid(curve, private_key, ndigits);
1320 }
1321 EXPORT_SYMBOL(ecc_is_key_valid);
1322 
1323 /*
1324  * ECC private keys are generated using the method of extra random bits,
1325  * equivalent to that described in FIPS 186-4, Appendix B.4.1.
1326  *
1327  * d = (c mod(n-1)) + 1    where c is a string of random bits, 64 bits longer
1328  *                         than requested
1329  * 0 <= c mod(n-1) <= n-2  and implies that
1330  * 1 <= d <= n-1
1331  *
1332  * This method generates a private key uniformly distributed in the range
1333  * [1, n-1].
1334  */
1335 int ecc_gen_privkey(unsigned int curve_id, unsigned int ndigits, u64 *privkey)
1336 {
1337 	const struct ecc_curve *curve = ecc_get_curve(curve_id);
1338 	u64 priv[ECC_MAX_DIGITS];
1339 	unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
1340 	unsigned int nbits = vli_num_bits(curve->n, ndigits);
1341 	int err;
1342 
1343 	/* Check that N is included in Table 1 of FIPS 186-4, section 6.1.1 */
1344 	if (nbits < 160 || ndigits > ARRAY_SIZE(priv))
1345 		return -EINVAL;
1346 
1347 	/*
1348 	 * FIPS 186-4 recommends that the private key should be obtained from a
1349 	 * RBG with a security strength equal to or greater than the security
1350 	 * strength associated with N.
1351 	 *
1352 	 * The maximum security strength identified by NIST SP800-57pt1r4 for
1353 	 * ECC is 256 (N >= 512).
1354 	 *
1355 	 * This condition is met by the default RNG because it selects a favored
1356 	 * DRBG with a security strength of 256.
1357 	 */
1358 	if (crypto_get_default_rng())
1359 		return -EFAULT;
1360 
1361 	err = crypto_rng_get_bytes(crypto_default_rng, (u8 *)priv, nbytes);
1362 	crypto_put_default_rng();
1363 	if (err)
1364 		return err;
1365 
1366 	/* Make sure the private key is in the valid range. */
1367 	if (__ecc_is_key_valid(curve, priv, ndigits))
1368 		return -EINVAL;
1369 
1370 	ecc_swap_digits(priv, privkey, ndigits);
1371 
1372 	return 0;
1373 }
1374 EXPORT_SYMBOL(ecc_gen_privkey);
1375 
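/*
 * Illustrative sketch only (hypothetical helper): the "extra random bits"
 * construction described above, on toy-sized integers.  With c drawn at
 * least 16 bits wider than n, the bias of the reduction below is at most
 * about 2^-16; the real code draws 64 extra bits.
 */
static u32 __maybe_unused extra_random_bits_demo(u32 c, u16 n)
{
	/* d = (c mod (n - 1)) + 1 lies in [1, n - 1]. */
	return c % (u32)(n - 1) + 1;
}
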
1376 int ecc_make_pub_key(unsigned int curve_id, unsigned int ndigits,
1377 		     const u64 *private_key, u64 *public_key)
1378 {
1379 	int ret = 0;
1380 	struct ecc_point *pk;
1381 	u64 priv[ECC_MAX_DIGITS];
1382 	const struct ecc_curve *curve = ecc_get_curve(curve_id);
1383 
1384 	if (!private_key || !curve || ndigits > ARRAY_SIZE(priv)) {
1385 		ret = -EINVAL;
1386 		goto out;
1387 	}
1388 
1389 	ecc_swap_digits(private_key, priv, ndigits);
1390 
1391 	pk = ecc_alloc_point(ndigits);
1392 	if (!pk) {
1393 		ret = -ENOMEM;
1394 		goto out;
1395 	}
1396 
1397 	ecc_point_mult(pk, &curve->g, priv, NULL, curve, ndigits);
1398 
1399 	/* SP800-56A rev 3 5.6.2.1.3 key check */
1400 	if (ecc_is_pubkey_valid_full(curve, pk)) {
1401 		ret = -EAGAIN;
1402 		goto err_free_point;
1403 	}
1404 
1405 	ecc_swap_digits(pk->x, public_key, ndigits);
1406 	ecc_swap_digits(pk->y, &public_key[ndigits], ndigits);
1407 
1408 err_free_point:
1409 	ecc_free_point(pk);
1410 out:
1411 	return ret;
1412 }
1413 EXPORT_SYMBOL(ecc_make_pub_key);
1414 
1415 /* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */
1416 int ecc_is_pubkey_valid_partial(const struct ecc_curve *curve,
1417 				struct ecc_point *pk)
1418 {
1419 	u64 yy[ECC_MAX_DIGITS], xxx[ECC_MAX_DIGITS], w[ECC_MAX_DIGITS];
1420 
1421 	if (WARN_ON(pk->ndigits != curve->g.ndigits))
1422 		return -EINVAL;
1423 
1424 	/* Check 1: Verify key is not the zero point. */
1425 	if (ecc_point_is_zero(pk))
1426 		return -EINVAL;
1427 
1428 	/* Check 2: Verify key is in the range [1, p-1]. */
1429 	if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1)
1430 		return -EINVAL;
1431 	if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1)
1432 		return -EINVAL;
1433 
1434 	/* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */
1435 	vli_mod_square_fast(yy, pk->y, curve->p, pk->ndigits); /* y^2 */
1436 	vli_mod_square_fast(xxx, pk->x, curve->p, pk->ndigits); /* x^2 */
1437 	vli_mod_mult_fast(xxx, xxx, pk->x, curve->p, pk->ndigits); /* x^3 */
1438 	vli_mod_mult_fast(w, curve->a, pk->x, curve->p, pk->ndigits); /* a·x */
1439 	vli_mod_add(w, w, curve->b, curve->p, pk->ndigits); /* a·x + b */
1440 	vli_mod_add(w, w, xxx, curve->p, pk->ndigits); /* x^3 + a·x + b */
1441 	if (vli_cmp(yy, w, pk->ndigits) != 0) /* Equation */
1442 		return -EINVAL;
1443 
1444 	return 0;
1445 }
1446 EXPORT_SYMBOL(ecc_is_pubkey_valid_partial);
1447 
1448 /* SP800-56A section 5.6.2.3.3 full verification */
1449 int ecc_is_pubkey_valid_full(const struct ecc_curve *curve,
1450 			     struct ecc_point *pk)
1451 {
1452 	struct ecc_point *nQ;
1453 
1454 	/* Checks 1 through 3 */
1455 	int ret = ecc_is_pubkey_valid_partial(curve, pk);
1456 
1457 	if (ret)
1458 		return ret;
1459 
1460 	/* Check 4: Verify that nQ is the zero point. */
1461 	nQ = ecc_alloc_point(pk->ndigits);
1462 	if (!nQ)
1463 		return -ENOMEM;
1464 
1465 	ecc_point_mult(nQ, pk, curve->n, NULL, curve, pk->ndigits);
1466 	if (!ecc_point_is_zero(nQ))
1467 		ret = -EINVAL;
1468 
1469 	ecc_free_point(nQ);
1470 
1471 	return ret;
1472 }
1473 EXPORT_SYMBOL(ecc_is_pubkey_valid_full);
1474 
1475 int crypto_ecdh_shared_secret(unsigned int curve_id, unsigned int ndigits,
1476 			      const u64 *private_key, const u64 *public_key,
1477 			      u64 *secret)
1478 {
1479 	int ret = 0;
1480 	struct ecc_point *product, *pk;
1481 	u64 priv[ECC_MAX_DIGITS];
1482 	u64 rand_z[ECC_MAX_DIGITS];
1483 	unsigned int nbytes;
1484 	const struct ecc_curve *curve = ecc_get_curve(curve_id);
1485 
1486 	if (!private_key || !public_key || !curve ||
1487 	    ndigits > ARRAY_SIZE(priv) || ndigits > ARRAY_SIZE(rand_z)) {
1488 		ret = -EINVAL;
1489 		goto out;
1490 	}
1491 
1492 	nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
1493 
1494 	get_random_bytes(rand_z, nbytes);
1495 
1496 	pk = ecc_alloc_point(ndigits);
1497 	if (!pk) {
1498 		ret = -ENOMEM;
1499 		goto out;
1500 	}
1501 
1502 	ecc_swap_digits(public_key, pk->x, ndigits);
1503 	ecc_swap_digits(&public_key[ndigits], pk->y, ndigits);
1504 	ret = ecc_is_pubkey_valid_partial(curve, pk);
1505 	if (ret)
1506 		goto err_alloc_product;
1507 
1508 	ecc_swap_digits(private_key, priv, ndigits);
1509 
1510 	product = ecc_alloc_point(ndigits);
1511 	if (!product) {
1512 		ret = -ENOMEM;
1513 		goto err_alloc_product;
1514 	}
1515 
1516 	ecc_point_mult(product, pk, priv, rand_z, curve, ndigits);
1517 
1518 	if (ecc_point_is_zero(product)) {
1519 		ret = -EFAULT;
1520 		goto err_validity;
1521 	}
1522 
1523 	ecc_swap_digits(product->x, secret, ndigits);
1524 
1525 err_validity:
1526 	memzero_explicit(priv, sizeof(priv));
1527 	memzero_explicit(rand_z, sizeof(rand_z));
1528 	ecc_free_point(product);
1529 err_alloc_product:
1530 	ecc_free_point(pk);
1531 out:
1532 	return ret;
1533 }
1534 EXPORT_SYMBOL(crypto_ecdh_shared_secret);
1535 
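/*
 * Illustrative sketch only (hypothetical, never called): how a caller might
 * drive the exported helpers above for a NIST P-256 ECDH exchange.  Error
 * handling is omitted and all sizes assume ndigits == 4.
 */
static void __maybe_unused ecdh_flow_sketch(void)
{
	const unsigned int ndigits = 4;			/* NIST P-256 */
	u64 priv_a[4], pub_a[8], secret_a[4];
	u64 priv_b[4], pub_b[8], secret_b[4];

	ecc_gen_privkey(ECC_CURVE_NIST_P256, ndigits, priv_a);
	ecc_make_pub_key(ECC_CURVE_NIST_P256, ndigits, priv_a, pub_a);
	ecc_gen_privkey(ECC_CURVE_NIST_P256, ndigits, priv_b);
	ecc_make_pub_key(ECC_CURVE_NIST_P256, ndigits, priv_b, pub_b);

	/* Both sides derive the same x coordinate as the shared secret. */
	crypto_ecdh_shared_secret(ECC_CURVE_NIST_P256, ndigits,
				  priv_a, pub_b, secret_a);
	crypto_ecdh_shared_secret(ECC_CURVE_NIST_P256, ndigits,
				  priv_b, pub_a, secret_b);
}
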
1536 MODULE_LICENSE("Dual BSD/GPL");
1537