1 /* 2 * Copyright (c) 2013, 2014 Kenneth MacKay. All rights reserved. 3 * Copyright (c) 2019 Vitaly Chikunov <vt@altlinux.org> 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are 7 * met: 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 18 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 19 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 20 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 */ 26 27 #include <crypto/ecc_curve.h> 28 #include <linux/module.h> 29 #include <linux/random.h> 30 #include <linux/slab.h> 31 #include <linux/swab.h> 32 #include <linux/fips.h> 33 #include <crypto/ecdh.h> 34 #include <crypto/rng.h> 35 #include <crypto/internal/ecc.h> 36 #include <asm/unaligned.h> 37 #include <linux/ratelimit.h> 38 39 #include "ecc_curve_defs.h" 40 41 typedef struct { 42 u64 m_low; 43 u64 m_high; 44 } uint128_t; 45 46 /* Returns curv25519 curve param */ 47 const struct ecc_curve *ecc_get_curve25519(void) 48 { 49 return &ecc_25519; 50 } 51 EXPORT_SYMBOL(ecc_get_curve25519); 52 53 const struct ecc_curve *ecc_get_curve(unsigned int curve_id) 54 { 55 switch (curve_id) { 56 /* In FIPS mode only allow P256 and higher */ 57 case ECC_CURVE_NIST_P192: 58 return fips_enabled ? NULL : &nist_p192; 59 case ECC_CURVE_NIST_P256: 60 return &nist_p256; 61 case ECC_CURVE_NIST_P384: 62 return &nist_p384; 63 default: 64 return NULL; 65 } 66 } 67 EXPORT_SYMBOL(ecc_get_curve); 68 69 static u64 *ecc_alloc_digits_space(unsigned int ndigits) 70 { 71 size_t len = ndigits * sizeof(u64); 72 73 if (!len) 74 return NULL; 75 76 return kmalloc(len, GFP_KERNEL); 77 } 78 79 static void ecc_free_digits_space(u64 *space) 80 { 81 kfree_sensitive(space); 82 } 83 84 struct ecc_point *ecc_alloc_point(unsigned int ndigits) 85 { 86 struct ecc_point *p = kmalloc(sizeof(*p), GFP_KERNEL); 87 88 if (!p) 89 return NULL; 90 91 p->x = ecc_alloc_digits_space(ndigits); 92 if (!p->x) 93 goto err_alloc_x; 94 95 p->y = ecc_alloc_digits_space(ndigits); 96 if (!p->y) 97 goto err_alloc_y; 98 99 p->ndigits = ndigits; 100 101 return p; 102 103 err_alloc_y: 104 ecc_free_digits_space(p->x); 105 err_alloc_x: 106 kfree(p); 107 return NULL; 108 } 109 EXPORT_SYMBOL(ecc_alloc_point); 110 111 void ecc_free_point(struct ecc_point *p) 112 { 113 if (!p) 114 return; 115 116 kfree_sensitive(p->x); 117 kfree_sensitive(p->y); 118 kfree_sensitive(p); 119 } 120 EXPORT_SYMBOL(ecc_free_point); 121 122 static void 
vli_clear(u64 *vli, unsigned int ndigits) 123 { 124 int i; 125 126 for (i = 0; i < ndigits; i++) 127 vli[i] = 0; 128 } 129 130 /* Returns true if vli == 0, false otherwise. */ 131 bool vli_is_zero(const u64 *vli, unsigned int ndigits) 132 { 133 int i; 134 135 for (i = 0; i < ndigits; i++) { 136 if (vli[i]) 137 return false; 138 } 139 140 return true; 141 } 142 EXPORT_SYMBOL(vli_is_zero); 143 144 /* Returns nonzero if bit of vli is set. */ 145 static u64 vli_test_bit(const u64 *vli, unsigned int bit) 146 { 147 return (vli[bit / 64] & ((u64)1 << (bit % 64))); 148 } 149 150 static bool vli_is_negative(const u64 *vli, unsigned int ndigits) 151 { 152 return vli_test_bit(vli, ndigits * 64 - 1); 153 } 154 155 /* Counts the number of 64-bit "digits" in vli. */ 156 static unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits) 157 { 158 int i; 159 160 /* Search from the end until we find a non-zero digit. 161 * We do it in reverse because we expect that most digits will 162 * be nonzero. 163 */ 164 for (i = ndigits - 1; i >= 0 && vli[i] == 0; i--); 165 166 return (i + 1); 167 } 168 169 /* Counts the number of bits required for vli. */ 170 unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits) 171 { 172 unsigned int i, num_digits; 173 u64 digit; 174 175 num_digits = vli_num_digits(vli, ndigits); 176 if (num_digits == 0) 177 return 0; 178 179 digit = vli[num_digits - 1]; 180 for (i = 0; digit; i++) 181 digit >>= 1; 182 183 return ((num_digits - 1) * 64 + i); 184 } 185 EXPORT_SYMBOL(vli_num_bits); 186 187 /* Set dest from unaligned bit string src. 
*/ 188 void vli_from_be64(u64 *dest, const void *src, unsigned int ndigits) 189 { 190 int i; 191 const u64 *from = src; 192 193 for (i = 0; i < ndigits; i++) 194 dest[i] = get_unaligned_be64(&from[ndigits - 1 - i]); 195 } 196 EXPORT_SYMBOL(vli_from_be64); 197 198 void vli_from_le64(u64 *dest, const void *src, unsigned int ndigits) 199 { 200 int i; 201 const u64 *from = src; 202 203 for (i = 0; i < ndigits; i++) 204 dest[i] = get_unaligned_le64(&from[i]); 205 } 206 EXPORT_SYMBOL(vli_from_le64); 207 208 /* Sets dest = src. */ 209 static void vli_set(u64 *dest, const u64 *src, unsigned int ndigits) 210 { 211 int i; 212 213 for (i = 0; i < ndigits; i++) 214 dest[i] = src[i]; 215 } 216 217 /* Returns sign of left - right. */ 218 int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits) 219 { 220 int i; 221 222 for (i = ndigits - 1; i >= 0; i--) { 223 if (left[i] > right[i]) 224 return 1; 225 else if (left[i] < right[i]) 226 return -1; 227 } 228 229 return 0; 230 } 231 EXPORT_SYMBOL(vli_cmp); 232 233 /* Computes result = in << c, returning carry. Can modify in place 234 * (if result == in). 0 < shift < 64. 235 */ 236 static u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift, 237 unsigned int ndigits) 238 { 239 u64 carry = 0; 240 int i; 241 242 for (i = 0; i < ndigits; i++) { 243 u64 temp = in[i]; 244 245 result[i] = (temp << shift) | carry; 246 carry = temp >> (64 - shift); 247 } 248 249 return carry; 250 } 251 252 /* Computes vli = vli >> 1. */ 253 static void vli_rshift1(u64 *vli, unsigned int ndigits) 254 { 255 u64 *end = vli; 256 u64 carry = 0; 257 258 vli += ndigits; 259 260 while (vli-- > end) { 261 u64 temp = *vli; 262 *vli = (temp >> 1) | carry; 263 carry = temp << 63; 264 } 265 } 266 267 /* Computes result = left + right, returning carry. Can modify in place. 
*/ 268 static u64 vli_add(u64 *result, const u64 *left, const u64 *right, 269 unsigned int ndigits) 270 { 271 u64 carry = 0; 272 int i; 273 274 for (i = 0; i < ndigits; i++) { 275 u64 sum; 276 277 sum = left[i] + right[i] + carry; 278 if (sum != left[i]) 279 carry = (sum < left[i]); 280 281 result[i] = sum; 282 } 283 284 return carry; 285 } 286 287 /* Computes result = left + right, returning carry. Can modify in place. */ 288 static u64 vli_uadd(u64 *result, const u64 *left, u64 right, 289 unsigned int ndigits) 290 { 291 u64 carry = right; 292 int i; 293 294 for (i = 0; i < ndigits; i++) { 295 u64 sum; 296 297 sum = left[i] + carry; 298 if (sum != left[i]) 299 carry = (sum < left[i]); 300 else 301 carry = !!carry; 302 303 result[i] = sum; 304 } 305 306 return carry; 307 } 308 309 /* Computes result = left - right, returning borrow. Can modify in place. */ 310 u64 vli_sub(u64 *result, const u64 *left, const u64 *right, 311 unsigned int ndigits) 312 { 313 u64 borrow = 0; 314 int i; 315 316 for (i = 0; i < ndigits; i++) { 317 u64 diff; 318 319 diff = left[i] - right[i] - borrow; 320 if (diff != left[i]) 321 borrow = (diff > left[i]); 322 323 result[i] = diff; 324 } 325 326 return borrow; 327 } 328 EXPORT_SYMBOL(vli_sub); 329 330 /* Computes result = left - right, returning borrow. Can modify in place. 
 */
static u64 vli_usub(u64 *result, const u64 *left, u64 right,
	     unsigned int ndigits)
{
	/* The single-word subtrahend seeds the borrow and then ripples up. */
	u64 borrow = right;
	int i;

	for (i = 0; i < ndigits; i++) {
		u64 diff;

		diff = left[i] - borrow;
		/* diff == left[i] only when borrow is 0, so it stays 0. */
		if (diff != left[i])
			borrow = (diff > left[i]);

		result[i] = diff;
	}

	return borrow;
}

/* Returns the full 128-bit product of two 64-bit words. */
static uint128_t mul_64_64(u64 left, u64 right)
{
	uint128_t result;
#if defined(CONFIG_ARCH_SUPPORTS_INT128)
	unsigned __int128 m = (unsigned __int128)left * right;

	result.m_low = m;
	result.m_high = m >> 64;
#else
	/* Schoolbook multiply on 32-bit halves. */
	u64 a0 = left & 0xffffffffull;
	u64 a1 = left >> 32;
	u64 b0 = right & 0xffffffffull;
	u64 b1 = right >> 32;
	u64 m0 = a0 * b0;
	u64 m1 = a0 * b1;
	u64 m2 = a1 * b0;
	u64 m3 = a1 * b1;

	/* a1*b0 + (m0 >> 32) cannot overflow; adding m1 may. */
	m2 += (m0 >> 32);
	m2 += m1;

	/* Overflow */
	if (m2 < m1)
		m3 += 0x100000000ull;

	result.m_low = (m0 & 0xffffffffull) | (m2 << 32);
	result.m_high = m3 + (m2 >> 32);
#endif
	return result;
}

/* Returns a + b modulo 2^128 (the carry out of the high word is dropped). */
static uint128_t add_128_128(uint128_t a, uint128_t b)
{
	uint128_t result;

	result.m_low = a.m_low + b.m_low;
	result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low);

	return result;
}

/*
 * Computes the 2 * ndigits digit product result = left * right.
 * result[k] is written while later iterations still read left/right, so
 * result must not overlap either input.
 */
static void vli_mult(u64 *result, const u64 *left, const u64 *right,
		     unsigned int ndigits)
{
	/* r2:r01 is a 192-bit column accumulator. */
	uint128_t r01 = { 0, 0 };
	u64 r2 = 0;
	unsigned int i, k;

	/* Compute each digit of result in sequence, maintaining the
	 * carries.
	 */
	for (k = 0; k < ndigits * 2 - 1; k++) {
		unsigned int min;

		if (k < ndigits)
			min = 0;
		else
			min = (k + 1) - ndigits;

		for (i = min; i <= k && i < ndigits; i++) {
			uint128_t product;

			product = mul_64_64(left[i], right[k - i]);

			r01 = add_128_128(r01, product);
			/* Catch the carry out of the 128-bit add into r2. */
			r2 += (r01.m_high < product.m_high);
		}

		/* Emit the low word and shift the accumulator down one word. */
		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = r2;
		r2 = 0;
	}

	result[ndigits * 2 - 1] = r01.m_low;
}

/* Compute product = left * right, for a small right value. */
static void vli_umult(u64 *result, const u64 *left, u32 right,
		      unsigned int ndigits)
{
	uint128_t r01 = { 0 };
	unsigned int k;

	for (k = 0; k < ndigits; k++) {
		uint128_t product;

		product = mul_64_64(left[k], right);
		r01 = add_128_128(r01, product);
		/* no carry */
		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = 0;
	}
	/* The final carry word, then zero-fill up to 2 * ndigits digits. */
	result[k] = r01.m_low;
	for (++k; k < ndigits * 2; k++)
		result[k] = 0;
}

/*
 * Computes the 2 * ndigits digit square result = left^2.
 * Each cross product left[i]*left[j] (i != j) is computed once and doubled.
 */
static void vli_square(u64 *result, const u64 *left, unsigned int ndigits)
{
	uint128_t r01 = { 0, 0 };
	u64 r2 = 0;
	int i, k;

	for (k = 0; k < ndigits * 2 - 1; k++) {
		unsigned int min;

		if (k < ndigits)
			min = 0;
		else
			min = (k + 1) - ndigits;

		for (i = min; i <= k && i <= k - i; i++) {
			uint128_t product;

			product = mul_64_64(left[i], left[k - i]);

			/* Off-diagonal term: double it, spilling the top bit into r2. */
			if (i < k - i) {
				r2 += product.m_high >> 63;
				product.m_high = (product.m_high << 1) |
						 (product.m_low >> 63);
				product.m_low <<= 1;
			}

			r01 = add_128_128(r01, product);
			r2 += (r01.m_high < product.m_high);
		}

		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = r2;
		r2 = 0;
	}

	result[ndigits * 2 - 1] = r01.m_low;
}

/* Computes result = (left + right) % mod.
 * Assumes that left < mod and right < mod, result != mod.
490 */ 491 static void vli_mod_add(u64 *result, const u64 *left, const u64 *right, 492 const u64 *mod, unsigned int ndigits) 493 { 494 u64 carry; 495 496 carry = vli_add(result, left, right, ndigits); 497 498 /* result > mod (result = mod + remainder), so subtract mod to 499 * get remainder. 500 */ 501 if (carry || vli_cmp(result, mod, ndigits) >= 0) 502 vli_sub(result, result, mod, ndigits); 503 } 504 505 /* Computes result = (left - right) % mod. 506 * Assumes that left < mod and right < mod, result != mod. 507 */ 508 static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right, 509 const u64 *mod, unsigned int ndigits) 510 { 511 u64 borrow = vli_sub(result, left, right, ndigits); 512 513 /* In this case, p_result == -diff == (max int) - diff. 514 * Since -x % d == d - x, we can get the correct result from 515 * result + mod (with overflow). 516 */ 517 if (borrow) 518 vli_add(result, result, mod, ndigits); 519 } 520 521 /* 522 * Computes result = product % mod 523 * for special form moduli: p = 2^k-c, for small c (note the minus sign) 524 * 525 * References: 526 * R. Crandall, C. Pomerance. Prime Numbers: A Computational Perspective. 527 * 9 Fast Algorithms for Large-Integer Arithmetic. 9.2.3 Moduli of special form 528 * Algorithm 9.2.13 (Fast mod operation for special-form moduli). 
529 */ 530 static void vli_mmod_special(u64 *result, const u64 *product, 531 const u64 *mod, unsigned int ndigits) 532 { 533 u64 c = -mod[0]; 534 u64 t[ECC_MAX_DIGITS * 2]; 535 u64 r[ECC_MAX_DIGITS * 2]; 536 537 vli_set(r, product, ndigits * 2); 538 while (!vli_is_zero(r + ndigits, ndigits)) { 539 vli_umult(t, r + ndigits, c, ndigits); 540 vli_clear(r + ndigits, ndigits); 541 vli_add(r, r, t, ndigits * 2); 542 } 543 vli_set(t, mod, ndigits); 544 vli_clear(t + ndigits, ndigits); 545 while (vli_cmp(r, t, ndigits * 2) >= 0) 546 vli_sub(r, r, t, ndigits * 2); 547 vli_set(result, r, ndigits); 548 } 549 550 /* 551 * Computes result = product % mod 552 * for special form moduli: p = 2^{k-1}+c, for small c (note the plus sign) 553 * where k-1 does not fit into qword boundary by -1 bit (such as 255). 554 555 * References (loosely based on): 556 * A. Menezes, P. van Oorschot, S. Vanstone. Handbook of Applied Cryptography. 557 * 14.3.4 Reduction methods for moduli of special form. Algorithm 14.47. 558 * URL: http://cacr.uwaterloo.ca/hac/about/chap14.pdf 559 * 560 * H. Cohen, G. Frey, R. Avanzi, C. Doche, T. Lange, K. Nguyen, F. Vercauteren. 561 * Handbook of Elliptic and Hyperelliptic Curve Cryptography. 
 * Algorithm 10.25 Fast reduction for special form moduli
 *
 * The code treats mod[0] as c, so mod is presumably 2^(64*ndigits-1) + mod[0]
 * with all middle digits zero -- confirm against the curve definitions.
 * 2*c is passed to vli_umult(), whose multiplier is a u32.
 */
static void vli_mmod_special2(u64 *result, const u64 *product,
			      const u64 *mod, unsigned int ndigits)
{
	u64 c2 = mod[0] * 2;
	u64 q[ECC_MAX_DIGITS];
	u64 r[ECC_MAX_DIGITS * 2];
	u64 m[ECC_MAX_DIGITS * 2]; /* expanded mod */
	int carry; /* last bit that doesn't fit into q */
	int i;

	vli_set(m, mod, ndigits);
	vli_clear(m + ndigits, ndigits);

	/* Split product at bit k-1: r = low k-1 bits, (q, carry) = the rest. */
	vli_set(r, product, ndigits);
	/* q and carry are top bits */
	vli_set(q, product + ndigits, ndigits);
	vli_clear(r + ndigits, ndigits);
	carry = vli_is_negative(r, ndigits);
	if (carry)
		r[ndigits - 1] &= (1ull << 63) - 1;
	/*
	 * The high part is (2*q + carry) * 2^(k-1) == -(q*2c + carry*c).
	 * Each fold multiplies by -c again, so the sign alternates with i.
	 */
	for (i = 1; carry || !vli_is_zero(q, ndigits); i++) {
		u64 qc[ECC_MAX_DIGITS * 2];

		vli_umult(qc, q, c2, ndigits);
		if (carry)
			vli_uadd(qc, qc, mod[0], ndigits * 2);
		vli_set(q, qc + ndigits, ndigits);
		vli_clear(qc + ndigits, ndigits);
		carry = vli_is_negative(qc, ndigits);
		if (carry)
			qc[ndigits - 1] &= (1ull << 63) - 1;
		if (i & 1)
			vli_sub(r, r, qc, ndigits * 2);
		else
			vli_add(r, r, qc, ndigits * 2);
	}
	/* r may have gone "negative" (top bit of 2*ndigits set): add mod back. */
	while (vli_is_negative(r, ndigits * 2))
		vli_add(r, r, m, ndigits * 2);
	while (vli_cmp(r, m, ndigits * 2) >= 0)
		vli_sub(r, r, m, ndigits * 2);

	vli_set(result, r, ndigits);
}

/*
 * Computes result = product % mod, where product is 2N words long.
 * Reference: Ken MacKay's micro-ecc.
 * Currently only designed to work for curve_p or curve_n.
 *
 * Binary long division: mod is shifted to the top of a 2N-word buffer and
 * trial-subtracted once per bit position. NOTE: product doubles as one of
 * the two working buffers and is overwritten.
 */
static void vli_mmod_slow(u64 *result, u64 *product, const u64 *mod,
			  unsigned int ndigits)
{
	u64 mod_m[2 * ECC_MAX_DIGITS];
	u64 tmp[2 * ECC_MAX_DIGITS];
	/* v[i] holds the current remainder; v[1 - i] receives the trial. */
	u64 *v[2] = { tmp, product };
	u64 carry = 0;
	unsigned int i;
	/* Shift mod so its highest set bit is at the maximum position. */
	int shift = (ndigits * 2 * 64) - vli_num_bits(mod, ndigits);
	int word_shift = shift / 64;
	int bit_shift = shift % 64;

	vli_clear(mod_m, word_shift);
	if (bit_shift > 0) {
		for (i = 0; i < ndigits; ++i) {
			mod_m[word_shift + i] = (mod[i] << bit_shift) | carry;
			carry = mod[i] >> (64 - bit_shift);
		}
	} else
		vli_set(mod_m + word_shift, mod, ndigits);

	/* shift + 1 rounds: try subtracting mod_m, then halve mod_m. */
	for (i = 1; shift >= 0; --shift) {
		u64 borrow = 0;
		unsigned int j;

		for (j = 0; j < ndigits * 2; ++j) {
			u64 diff = v[i][j] - mod_m[j] - borrow;

			if (diff != v[i][j])
				borrow = (diff > v[i][j]);
			v[1 - i][j] = diff;
		}
		i = !(i ^ borrow); /* Swap the index if there was no borrow */
		/* 2N-word right shift done as two N-word halves plus the seam bit. */
		vli_rshift1(mod_m, ndigits);
		mod_m[ndigits - 1] |= mod_m[ndigits] << (64 - 1);
		vli_rshift1(mod_m + ndigits, ndigits);
	}
	vli_set(result, v[i], ndigits);
}

/* Computes result = product % mod using Barrett's reduction with precomputed
 * value mu appended to the mod after ndigits, mu = (2^{2w} / mod) and have
 * length ndigits + 1, where mu * (2^w - 1) should not overflow ndigits
 * boundary.
 *
 * Reference:
 * R. Brent, P. Zimmermann. Modern Computer Arithmetic. 2010.
 * 2.4.1 Barrett's algorithm. Algorithm 2.5.
662 */ 663 static void vli_mmod_barrett(u64 *result, u64 *product, const u64 *mod, 664 unsigned int ndigits) 665 { 666 u64 q[ECC_MAX_DIGITS * 2]; 667 u64 r[ECC_MAX_DIGITS * 2]; 668 const u64 *mu = mod + ndigits; 669 670 vli_mult(q, product + ndigits, mu, ndigits); 671 if (mu[ndigits]) 672 vli_add(q + ndigits, q + ndigits, product + ndigits, ndigits); 673 vli_mult(r, mod, q + ndigits, ndigits); 674 vli_sub(r, product, r, ndigits * 2); 675 while (!vli_is_zero(r + ndigits, ndigits) || 676 vli_cmp(r, mod, ndigits) != -1) { 677 u64 carry; 678 679 carry = vli_sub(r, r, mod, ndigits); 680 vli_usub(r + ndigits, r + ndigits, carry, ndigits); 681 } 682 vli_set(result, r, ndigits); 683 } 684 685 /* Computes p_result = p_product % curve_p. 686 * See algorithm 5 and 6 from 687 * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf 688 */ 689 static void vli_mmod_fast_192(u64 *result, const u64 *product, 690 const u64 *curve_prime, u64 *tmp) 691 { 692 const unsigned int ndigits = 3; 693 int carry; 694 695 vli_set(result, product, ndigits); 696 697 vli_set(tmp, &product[3], ndigits); 698 carry = vli_add(result, result, tmp, ndigits); 699 700 tmp[0] = 0; 701 tmp[1] = product[3]; 702 tmp[2] = product[4]; 703 carry += vli_add(result, result, tmp, ndigits); 704 705 tmp[0] = tmp[1] = product[5]; 706 tmp[2] = 0; 707 carry += vli_add(result, result, tmp, ndigits); 708 709 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 710 carry -= vli_sub(result, result, curve_prime, ndigits); 711 } 712 713 /* Computes result = product % curve_prime 714 * from http://www.nsa.gov/ia/_files/nist-routines.pdf 715 */ 716 static void vli_mmod_fast_256(u64 *result, const u64 *product, 717 const u64 *curve_prime, u64 *tmp) 718 { 719 int carry; 720 const unsigned int ndigits = 4; 721 722 /* t */ 723 vli_set(result, product, ndigits); 724 725 /* s1 */ 726 tmp[0] = 0; 727 tmp[1] = product[5] & 0xffffffff00000000ull; 728 tmp[2] = product[6]; 729 tmp[3] = product[7]; 730 carry = vli_lshift(tmp, tmp, 1, 
ndigits); 731 carry += vli_add(result, result, tmp, ndigits); 732 733 /* s2 */ 734 tmp[1] = product[6] << 32; 735 tmp[2] = (product[6] >> 32) | (product[7] << 32); 736 tmp[3] = product[7] >> 32; 737 carry += vli_lshift(tmp, tmp, 1, ndigits); 738 carry += vli_add(result, result, tmp, ndigits); 739 740 /* s3 */ 741 tmp[0] = product[4]; 742 tmp[1] = product[5] & 0xffffffff; 743 tmp[2] = 0; 744 tmp[3] = product[7]; 745 carry += vli_add(result, result, tmp, ndigits); 746 747 /* s4 */ 748 tmp[0] = (product[4] >> 32) | (product[5] << 32); 749 tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull); 750 tmp[2] = product[7]; 751 tmp[3] = (product[6] >> 32) | (product[4] << 32); 752 carry += vli_add(result, result, tmp, ndigits); 753 754 /* d1 */ 755 tmp[0] = (product[5] >> 32) | (product[6] << 32); 756 tmp[1] = (product[6] >> 32); 757 tmp[2] = 0; 758 tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32); 759 carry -= vli_sub(result, result, tmp, ndigits); 760 761 /* d2 */ 762 tmp[0] = product[6]; 763 tmp[1] = product[7]; 764 tmp[2] = 0; 765 tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull); 766 carry -= vli_sub(result, result, tmp, ndigits); 767 768 /* d3 */ 769 tmp[0] = (product[6] >> 32) | (product[7] << 32); 770 tmp[1] = (product[7] >> 32) | (product[4] << 32); 771 tmp[2] = (product[4] >> 32) | (product[5] << 32); 772 tmp[3] = (product[6] << 32); 773 carry -= vli_sub(result, result, tmp, ndigits); 774 775 /* d4 */ 776 tmp[0] = product[7]; 777 tmp[1] = product[4] & 0xffffffff00000000ull; 778 tmp[2] = product[5]; 779 tmp[3] = product[6] & 0xffffffff00000000ull; 780 carry -= vli_sub(result, result, tmp, ndigits); 781 782 if (carry < 0) { 783 do { 784 carry += vli_add(result, result, curve_prime, ndigits); 785 } while (carry < 0); 786 } else { 787 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 788 carry -= vli_sub(result, result, curve_prime, ndigits); 789 } 790 } 791 792 #define SL32OR32(x32, y32) (((u64)x32 << 32) | y32) 793 
#define AND64H(x64) (x64 & 0xffFFffFF00000000ull) 794 #define AND64L(x64) (x64 & 0x00000000ffFFffFFull) 795 796 /* Computes result = product % curve_prime 797 * from "Mathematical routines for the NIST prime elliptic curves" 798 */ 799 static void vli_mmod_fast_384(u64 *result, const u64 *product, 800 const u64 *curve_prime, u64 *tmp) 801 { 802 int carry; 803 const unsigned int ndigits = 6; 804 805 /* t */ 806 vli_set(result, product, ndigits); 807 808 /* s1 */ 809 tmp[0] = 0; // 0 || 0 810 tmp[1] = 0; // 0 || 0 811 tmp[2] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 812 tmp[3] = product[11]>>32; // 0 ||a23 813 tmp[4] = 0; // 0 || 0 814 tmp[5] = 0; // 0 || 0 815 carry = vli_lshift(tmp, tmp, 1, ndigits); 816 carry += vli_add(result, result, tmp, ndigits); 817 818 /* s2 */ 819 tmp[0] = product[6]; //a13||a12 820 tmp[1] = product[7]; //a15||a14 821 tmp[2] = product[8]; //a17||a16 822 tmp[3] = product[9]; //a19||a18 823 tmp[4] = product[10]; //a21||a20 824 tmp[5] = product[11]; //a23||a22 825 carry += vli_add(result, result, tmp, ndigits); 826 827 /* s3 */ 828 tmp[0] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 829 tmp[1] = SL32OR32(product[6], (product[11]>>32)); //a12||a23 830 tmp[2] = SL32OR32(product[7], (product[6])>>32); //a14||a13 831 tmp[3] = SL32OR32(product[8], (product[7]>>32)); //a16||a15 832 tmp[4] = SL32OR32(product[9], (product[8]>>32)); //a18||a17 833 tmp[5] = SL32OR32(product[10], (product[9]>>32)); //a20||a19 834 carry += vli_add(result, result, tmp, ndigits); 835 836 /* s4 */ 837 tmp[0] = AND64H(product[11]); //a23|| 0 838 tmp[1] = (product[10]<<32); //a20|| 0 839 tmp[2] = product[6]; //a13||a12 840 tmp[3] = product[7]; //a15||a14 841 tmp[4] = product[8]; //a17||a16 842 tmp[5] = product[9]; //a19||a18 843 carry += vli_add(result, result, tmp, ndigits); 844 845 /* s5 */ 846 tmp[0] = 0; // 0|| 0 847 tmp[1] = 0; // 0|| 0 848 tmp[2] = product[10]; //a21||a20 849 tmp[3] = product[11]; //a23||a22 850 tmp[4] = 0; // 0|| 0 851 tmp[5] = 0; 
// 0|| 0 852 carry += vli_add(result, result, tmp, ndigits); 853 854 /* s6 */ 855 tmp[0] = AND64L(product[10]); // 0 ||a20 856 tmp[1] = AND64H(product[10]); //a21|| 0 857 tmp[2] = product[11]; //a23||a22 858 tmp[3] = 0; // 0 || 0 859 tmp[4] = 0; // 0 || 0 860 tmp[5] = 0; // 0 || 0 861 carry += vli_add(result, result, tmp, ndigits); 862 863 /* d1 */ 864 tmp[0] = SL32OR32(product[6], (product[11]>>32)); //a12||a23 865 tmp[1] = SL32OR32(product[7], (product[6]>>32)); //a14||a13 866 tmp[2] = SL32OR32(product[8], (product[7]>>32)); //a16||a15 867 tmp[3] = SL32OR32(product[9], (product[8]>>32)); //a18||a17 868 tmp[4] = SL32OR32(product[10], (product[9]>>32)); //a20||a19 869 tmp[5] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 870 carry -= vli_sub(result, result, tmp, ndigits); 871 872 /* d2 */ 873 tmp[0] = (product[10]<<32); //a20|| 0 874 tmp[1] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 875 tmp[2] = (product[11]>>32); // 0 ||a23 876 tmp[3] = 0; // 0 || 0 877 tmp[4] = 0; // 0 || 0 878 tmp[5] = 0; // 0 || 0 879 carry -= vli_sub(result, result, tmp, ndigits); 880 881 /* d3 */ 882 tmp[0] = 0; // 0 || 0 883 tmp[1] = AND64H(product[11]); //a23|| 0 884 tmp[2] = product[11]>>32; // 0 ||a23 885 tmp[3] = 0; // 0 || 0 886 tmp[4] = 0; // 0 || 0 887 tmp[5] = 0; // 0 || 0 888 carry -= vli_sub(result, result, tmp, ndigits); 889 890 if (carry < 0) { 891 do { 892 carry += vli_add(result, result, curve_prime, ndigits); 893 } while (carry < 0); 894 } else { 895 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 896 carry -= vli_sub(result, result, curve_prime, ndigits); 897 } 898 899 } 900 901 #undef SL32OR32 902 #undef AND64H 903 #undef AND64L 904 905 /* Computes result = product % curve_prime for different curve_primes. 906 * 907 * Note that curve_primes are distinguished just by heuristic check and 908 * not by complete conformance check. 
909 */ 910 static bool vli_mmod_fast(u64 *result, u64 *product, 911 const struct ecc_curve *curve) 912 { 913 u64 tmp[2 * ECC_MAX_DIGITS]; 914 const u64 *curve_prime = curve->p; 915 const unsigned int ndigits = curve->g.ndigits; 916 917 /* All NIST curves have name prefix 'nist_' */ 918 if (strncmp(curve->name, "nist_", 5) != 0) { 919 /* Try to handle Pseudo-Marsenne primes. */ 920 if (curve_prime[ndigits - 1] == -1ull) { 921 vli_mmod_special(result, product, curve_prime, 922 ndigits); 923 return true; 924 } else if (curve_prime[ndigits - 1] == 1ull << 63 && 925 curve_prime[ndigits - 2] == 0) { 926 vli_mmod_special2(result, product, curve_prime, 927 ndigits); 928 return true; 929 } 930 vli_mmod_barrett(result, product, curve_prime, ndigits); 931 return true; 932 } 933 934 switch (ndigits) { 935 case 3: 936 vli_mmod_fast_192(result, product, curve_prime, tmp); 937 break; 938 case 4: 939 vli_mmod_fast_256(result, product, curve_prime, tmp); 940 break; 941 case 6: 942 vli_mmod_fast_384(result, product, curve_prime, tmp); 943 break; 944 default: 945 pr_err_ratelimited("ecc: unsupported digits size!\n"); 946 return false; 947 } 948 949 return true; 950 } 951 952 /* Computes result = (left * right) % mod. 953 * Assumes that mod is big enough curve order. 954 */ 955 void vli_mod_mult_slow(u64 *result, const u64 *left, const u64 *right, 956 const u64 *mod, unsigned int ndigits) 957 { 958 u64 product[ECC_MAX_DIGITS * 2]; 959 960 vli_mult(product, left, right, ndigits); 961 vli_mmod_slow(result, product, mod, ndigits); 962 } 963 EXPORT_SYMBOL(vli_mod_mult_slow); 964 965 /* Computes result = (left * right) % curve_prime. */ 966 static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right, 967 const struct ecc_curve *curve) 968 { 969 u64 product[2 * ECC_MAX_DIGITS]; 970 971 vli_mult(product, left, right, curve->g.ndigits); 972 vli_mmod_fast(result, product, curve); 973 } 974 975 /* Computes result = left^2 % curve_prime. 
 */
static void vli_mod_square_fast(u64 *result, const u64 *left,
				const struct ecc_curve *curve)
{
	u64 product[2 * ECC_MAX_DIGITS];

	vli_square(product, left, curve->g.ndigits);
	vli_mmod_fast(result, product, curve);
}

/* True if the lowest digit of vli is even. NOTE(review): argument is not
 * parenthesised; only safe for the plain identifiers used below.
 */
#define EVEN(vli) (!(vli[0] & 1))
/* Computes result = (1 / p_input) % mod. All VLIs are the same size.
 * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide"
 * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf
 *
 * A zero input yields a zero result. The halving steps add mod to make an
 * odd value even, which presumably requires mod to be odd -- confirm.
 */
void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod,
			unsigned int ndigits)
{
	u64 a[ECC_MAX_DIGITS], b[ECC_MAX_DIGITS];
	u64 u[ECC_MAX_DIGITS], v[ECC_MAX_DIGITS];
	u64 carry;
	int cmp_result;

	if (vli_is_zero(input, ndigits)) {
		vli_clear(result, ndigits);
		return;
	}

	/* Binary extended GCD: a = input, b = mod, u = 1, v = 0. */
	vli_set(a, input, ndigits);
	vli_set(b, mod, ndigits);
	vli_clear(u, ndigits);
	u[0] = 1;
	vli_clear(v, ndigits);

	while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) {
		carry = 0;

		if (EVEN(a)) {
			vli_rshift1(a, ndigits);

			/* Halve u mod mod; the add's carry becomes the new top bit. */
			if (!EVEN(u))
				carry = vli_add(u, u, mod, ndigits);

			vli_rshift1(u, ndigits);
			if (carry)
				u[ndigits - 1] |= 0x8000000000000000ull;
		} else if (EVEN(b)) {
			vli_rshift1(b, ndigits);

			if (!EVEN(v))
				carry = vli_add(v, v, mod, ndigits);

			vli_rshift1(v, ndigits);
			if (carry)
				v[ndigits - 1] |= 0x8000000000000000ull;
		} else if (cmp_result > 0) {
			vli_sub(a, a, b, ndigits);
			vli_rshift1(a, ndigits);

			if (vli_cmp(u, v, ndigits) < 0)
				vli_add(u, u, mod, ndigits);

			vli_sub(u, u, v, ndigits);
			if (!EVEN(u))
				carry = vli_add(u, u, mod, ndigits);

			vli_rshift1(u, ndigits);
			if (carry)
				u[ndigits - 1] |= 0x8000000000000000ull;
		} else {
			vli_sub(b, b, a, ndigits);
			vli_rshift1(b, ndigits);

			if (vli_cmp(v, u, ndigits) < 0)
				vli_add(v, v, mod, ndigits);

			vli_sub(v, v, u, ndigits);
			if (!EVEN(v))
				carry = vli_add(v, v, mod, ndigits);

			vli_rshift1(v, ndigits);
			if (carry)
				v[ndigits - 1] |= 0x8000000000000000ull;
		}
	}

	vli_set(result, u, ndigits);
}
EXPORT_SYMBOL(vli_mod_inv);

/* ------ Point operations ------ */

/* Returns true if p_point is the point at infinity, false otherwise.
 * The check is simply x == 0 && y == 0.
 */
bool ecc_point_is_zero(const struct ecc_point *point)
{
	return (vli_is_zero(point->x, point->ndigits) &&
		vli_is_zero(point->y, point->ndigits));
}
EXPORT_SYMBOL(ecc_point_is_zero);

/* Point multiplication algorithm using Montgomery's ladder with co-Z
 * coordinates. From https://eprint.iacr.org/2011/338.pdf
 */

/* Double in place.
 * Doubles the Jacobian point (x1, y1, z1); z1 == 0 returns unchanged.
 */
static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1,
				      const struct ecc_curve *curve)
{
	/* t1 = x, t2 = y, t3 = z */
	u64 t4[ECC_MAX_DIGITS];
	u64 t5[ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	if (vli_is_zero(z1, ndigits))
		return;

	/* t4 = y1^2 */
	vli_mod_square_fast(t4, y1, curve);
	/* t5 = x1*y1^2 = A */
	vli_mod_mult_fast(t5, x1, t4, curve);
	/* t4 = y1^4 */
	vli_mod_square_fast(t4, t4, curve);
	/* t2 = y1*z1 = z3 */
	vli_mod_mult_fast(y1, y1, z1, curve);
	/* t3 = z1^2 */
	vli_mod_square_fast(z1, z1, curve);

	/* t1 = x1 + z1^2 */
	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
	/* t3 = 2*z1^2 */
	vli_mod_add(z1, z1, z1, curve_prime, ndigits);
	/* t3 = x1 - z1^2 */
	vli_mod_sub(z1, x1, z1, curve_prime, ndigits);
	/* t1 = x1^2 - z1^4 */
	vli_mod_mult_fast(x1, x1, z1, curve);

	/* t3 = 2*(x1^2 - z1^4) */
	vli_mod_add(z1, x1, x1, curve_prime, ndigits);
	/* t1 = 3*(x1^2 - z1^4) */
	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
	/* Halve x1 mod p: if odd, add p first and keep its carry as the top bit. */
	if (vli_test_bit(x1, 0)) {
		u64 carry = vli_add(x1, x1, curve_prime, ndigits);

		vli_rshift1(x1, ndigits);
		x1[ndigits - 1] |= carry << 63;
	} else {
		vli_rshift1(x1, ndigits);
	}
	/* t1 = 3/2*(x1^2 - z1^4) = B */

	/* t3 = B^2 */
	vli_mod_square_fast(z1, x1, curve);
	/* t3 = B^2 - A */
	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
	/* t3 = B^2 - 2A = x3 */
	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
	/* t5 = A - x3 */
	vli_mod_sub(t5, t5, z1, curve_prime, ndigits);
	/* t1 = B * (A - x3) */
	vli_mod_mult_fast(x1, x1, t5, curve);
	/* t4 = B * (A - x3) - y1^4 = y3 */
	vli_mod_sub(t4, x1, t4, curve_prime, ndigits);

	/* Move (x3, y3, z3) out of the temporaries into the outputs. */
	vli_set(x1, z1, ndigits);
	vli_set(z1, y1, ndigits);
	vli_set(y1, t4, ndigits);
}

/* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */
static void apply_z(u64 *x1, u64 *y1, u64 *z, const struct ecc_curve *curve)
{
	u64 t1[ECC_MAX_DIGITS];

	vli_mod_square_fast(t1, z, curve);		/* z^2 */
	vli_mod_mult_fast(x1, x1, t1, curve);	/* x1 * z^2 */
	vli_mod_mult_fast(t1, t1, z, curve);	/* z^3 */
	vli_mod_mult_fast(y1, y1, t1, curve);	/* y1 * z^3 */
}

/* P = (x1, y1) => 2P, (x2, y2) => P'
 * p_initial_z may be NULL, in which case Z starts at 1.
 */
static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
				u64 *p_initial_z, const struct ecc_curve *curve)
{
	u64 z[ECC_MAX_DIGITS];
	const unsigned int ndigits = curve->g.ndigits;

	vli_set(x2, x1, ndigits);
	vli_set(y2, y1, ndigits);

	vli_clear(z, ndigits);
	z[0] = 1;

	if (p_initial_z)
		vli_set(z, p_initial_z, ndigits);

	apply_z(x1, y1, z, curve);

	ecc_point_double_jacobian(x1, y1, z, curve);

	/* Bring P' onto the same Z as 2P (co-Z representation). */
	apply_z(x2, y2, z, curve);
}

/* Input P = (x1, y1, Z), Q = (x2, y2, Z)
 * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3)
 * or P => P', Q => P + Q
 */
static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
		     const struct ecc_curve *curve)
{
	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
	u64 t5[ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	/* t5 = x2 - x1 */
	vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
	/* t5 = (x2 - x1)^2 = A */
	vli_mod_square_fast(t5, t5, curve);
	/* t1 = x1*A = B */
	vli_mod_mult_fast(x1, x1, t5, curve);
	/* t3 = x2*A = C */
	vli_mod_mult_fast(x2, x2, t5, curve);
	/* t4 = y2 - y1 */
	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
	/* t5 = (y2 - y1)^2 = D */
	vli_mod_square_fast(t5, y2, curve);

	/* t5 = D - B */
	vli_mod_sub(t5, t5, x1, curve_prime, ndigits);
	/* t5 = D - B - C = x3 */
	vli_mod_sub(t5, t5, x2, curve_prime, ndigits);
	/* t3 = C - B */
	vli_mod_sub(x2, x2, x1, curve_prime, ndigits);
	/* t2 = y1*(C - B) */
	vli_mod_mult_fast(y1, y1, x2, curve);
	/* t3 = B - x3 */
	vli_mod_sub(x2, x1, t5, curve_prime, ndigits);
	/* t4 = (y2 - y1)*(B - x3) */
	vli_mod_mult_fast(y2, y2, x2, curve);
	/* t4 = y3 */
	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);

	vli_set(x2, t5, ndigits);
}

/* Input P = (x1, y1, Z), Q = (x2, y2, Z)
 * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3)
 * or P => P - Q, Q => P + Q
 */
static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
		       const struct ecc_curve *curve)
{
	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
	u64 t5[ECC_MAX_DIGITS];
	u64 t6[ECC_MAX_DIGITS];
	u64 t7[ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	/* t5 = x2 - x1 */
	vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
	/* t5 = (x2 - x1)^2 = A */
	vli_mod_square_fast(t5, t5, curve);
	/* t1 = x1*A = B */
	vli_mod_mult_fast(x1, x1, t5, curve);
	/* t3 = x2*A = C */
	vli_mod_mult_fast(x2, x2, t5, curve);
	/* t4 = y2 + y1 */
	vli_mod_add(t5, y2, y1, curve_prime, ndigits);
	/* t4 = y2 - y1 */
	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);

	/* t6 = C - B */
	vli_mod_sub(t6, x2, x1, curve_prime, ndigits);
	/* t2 = y1 * (C - B) */
	vli_mod_mult_fast(y1, y1, t6, curve);
	/* t6 = B + C */
	vli_mod_add(t6, x1, x2, curve_prime, ndigits);
	/* t3 = (y2 - y1)^2 */
	vli_mod_square_fast(x2, y2, curve);
	/* t3 = x3 */
	vli_mod_sub(x2, x2, t6, curve_prime, ndigits);

	/* t7 = B - x3 */
	vli_mod_sub(t7, x1, x2, curve_prime, ndigits);
	/* t4 = (y2 - y1)*(B - x3) */
	vli_mod_mult_fast(y2, y2, t7, curve);
	/* t4 = y3 */
	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);

	/* t7 = (y2 + y1)^2 = F */
	vli_mod_square_fast(t7, t5, curve);
	/* t7 = x3' */
	vli_mod_sub(t7, t7, t6, curve_prime, ndigits);
	/* t6 = x3' - B */
	vli_mod_sub(t6, t7, x1, curve_prime, ndigits);
	/* t6 = (y2 + y1)*(x3' - B) */
	vli_mod_mult_fast(t6, t6, t5, curve);
	/* t2 = y3' */
	vli_mod_sub(y1, t6, y1, curve_prime, ndigits);

	vli_set(x1, t7, ndigits);
}

static void ecc_point_mult(struct ecc_point *result,
			   const struct ecc_point *point, const u64 *scalar,
			   u64 *initial_z, const struct ecc_curve *curve,
			   unsigned int ndigits)
{
	/* R0 and R1 */
	u64 rx[2][ECC_MAX_DIGITS];
	u64 ry[2][ECC_MAX_DIGITS];
	u64 z[ECC_MAX_DIGITS];
	u64 sk[2][ECC_MAX_DIGITS];
	u64 *curve_prime = curve->p;
	int i, nb;
	int num_bits;
	int carry;

	/*
	 * Use scalar + n (if that carried) or scalar + 2n, and treat bit
	 * 64*ndigits as set via num_bits; presumably so the ladder always runs
	 * a fixed number of steps -- confirm.
	 */
	carry = vli_add(sk[0], scalar, curve->n, ndigits);
	vli_add(sk[1], sk[0], curve->n, ndigits);
	scalar = sk[!carry];
	num_bits = sizeof(u64) * ndigits * 8 + 1;

	vli_set(rx[1], point->x, ndigits);
	vli_set(ry[1], point->y, ndigits);

	xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve);

	for (i = num_bits - 2; i > 0; i--) {
curve); 1243 /* t4 = y2 + y1 */ 1244 vli_mod_add(t5, y2, y1, curve_prime, ndigits); 1245 /* t4 = y2 - y1 */ 1246 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1247 1248 /* t6 = C - B */ 1249 vli_mod_sub(t6, x2, x1, curve_prime, ndigits); 1250 /* t2 = y1 * (C - B) */ 1251 vli_mod_mult_fast(y1, y1, t6, curve); 1252 /* t6 = B + C */ 1253 vli_mod_add(t6, x1, x2, curve_prime, ndigits); 1254 /* t3 = (y2 - y1)^2 */ 1255 vli_mod_square_fast(x2, y2, curve); 1256 /* t3 = x3 */ 1257 vli_mod_sub(x2, x2, t6, curve_prime, ndigits); 1258 1259 /* t7 = B - x3 */ 1260 vli_mod_sub(t7, x1, x2, curve_prime, ndigits); 1261 /* t4 = (y2 - y1)*(B - x3) */ 1262 vli_mod_mult_fast(y2, y2, t7, curve); 1263 /* t4 = y3 */ 1264 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1265 1266 /* t7 = (y2 + y1)^2 = F */ 1267 vli_mod_square_fast(t7, t5, curve); 1268 /* t7 = x3' */ 1269 vli_mod_sub(t7, t7, t6, curve_prime, ndigits); 1270 /* t6 = x3' - B */ 1271 vli_mod_sub(t6, t7, x1, curve_prime, ndigits); 1272 /* t6 = (y2 + y1)*(x3' - B) */ 1273 vli_mod_mult_fast(t6, t6, t5, curve); 1274 /* t2 = y3' */ 1275 vli_mod_sub(y1, t6, y1, curve_prime, ndigits); 1276 1277 vli_set(x1, t7, ndigits); 1278 } 1279 1280 static void ecc_point_mult(struct ecc_point *result, 1281 const struct ecc_point *point, const u64 *scalar, 1282 u64 *initial_z, const struct ecc_curve *curve, 1283 unsigned int ndigits) 1284 { 1285 /* R0 and R1 */ 1286 u64 rx[2][ECC_MAX_DIGITS]; 1287 u64 ry[2][ECC_MAX_DIGITS]; 1288 u64 z[ECC_MAX_DIGITS]; 1289 u64 sk[2][ECC_MAX_DIGITS]; 1290 u64 *curve_prime = curve->p; 1291 int i, nb; 1292 int num_bits; 1293 int carry; 1294 1295 carry = vli_add(sk[0], scalar, curve->n, ndigits); 1296 vli_add(sk[1], sk[0], curve->n, ndigits); 1297 scalar = sk[!carry]; 1298 num_bits = sizeof(u64) * ndigits * 8 + 1; 1299 1300 vli_set(rx[1], point->x, ndigits); 1301 vli_set(ry[1], point->y, ndigits); 1302 1303 xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve); 1304 1305 for (i = num_bits - 2; i > 0; i--) { 
1306 nb = !vli_test_bit(scalar, i); 1307 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve); 1308 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve); 1309 } 1310 1311 nb = !vli_test_bit(scalar, 0); 1312 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve); 1313 1314 /* Find final 1/Z value. */ 1315 /* X1 - X0 */ 1316 vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits); 1317 /* Yb * (X1 - X0) */ 1318 vli_mod_mult_fast(z, z, ry[1 - nb], curve); 1319 /* xP * Yb * (X1 - X0) */ 1320 vli_mod_mult_fast(z, z, point->x, curve); 1321 1322 /* 1 / (xP * Yb * (X1 - X0)) */ 1323 vli_mod_inv(z, z, curve_prime, point->ndigits); 1324 1325 /* yP / (xP * Yb * (X1 - X0)) */ 1326 vli_mod_mult_fast(z, z, point->y, curve); 1327 /* Xb * yP / (xP * Yb * (X1 - X0)) */ 1328 vli_mod_mult_fast(z, z, rx[1 - nb], curve); 1329 /* End 1/Z calculation */ 1330 1331 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve); 1332 1333 apply_z(rx[0], ry[0], z, curve); 1334 1335 vli_set(result->x, rx[0], ndigits); 1336 vli_set(result->y, ry[0], ndigits); 1337 } 1338 1339 /* Computes R = P + Q mod p */ 1340 static void ecc_point_add(const struct ecc_point *result, 1341 const struct ecc_point *p, const struct ecc_point *q, 1342 const struct ecc_curve *curve) 1343 { 1344 u64 z[ECC_MAX_DIGITS]; 1345 u64 px[ECC_MAX_DIGITS]; 1346 u64 py[ECC_MAX_DIGITS]; 1347 unsigned int ndigits = curve->g.ndigits; 1348 1349 vli_set(result->x, q->x, ndigits); 1350 vli_set(result->y, q->y, ndigits); 1351 vli_mod_sub(z, result->x, p->x, curve->p, ndigits); 1352 vli_set(px, p->x, ndigits); 1353 vli_set(py, p->y, ndigits); 1354 xycz_add(px, py, result->x, result->y, curve); 1355 vli_mod_inv(z, z, curve->p, ndigits); 1356 apply_z(result->x, result->y, z, curve); 1357 } 1358 1359 /* Computes R = u1P + u2Q mod p using Shamir's trick. 1360 * Based on: Kenneth MacKay's micro-ecc (2014). 
1361 */ 1362 void ecc_point_mult_shamir(const struct ecc_point *result, 1363 const u64 *u1, const struct ecc_point *p, 1364 const u64 *u2, const struct ecc_point *q, 1365 const struct ecc_curve *curve) 1366 { 1367 u64 z[ECC_MAX_DIGITS]; 1368 u64 sump[2][ECC_MAX_DIGITS]; 1369 u64 *rx = result->x; 1370 u64 *ry = result->y; 1371 unsigned int ndigits = curve->g.ndigits; 1372 unsigned int num_bits; 1373 struct ecc_point sum = ECC_POINT_INIT(sump[0], sump[1], ndigits); 1374 const struct ecc_point *points[4]; 1375 const struct ecc_point *point; 1376 unsigned int idx; 1377 int i; 1378 1379 ecc_point_add(&sum, p, q, curve); 1380 points[0] = NULL; 1381 points[1] = p; 1382 points[2] = q; 1383 points[3] = ∑ 1384 1385 num_bits = max(vli_num_bits(u1, ndigits), vli_num_bits(u2, ndigits)); 1386 i = num_bits - 1; 1387 idx = (!!vli_test_bit(u1, i)) | ((!!vli_test_bit(u2, i)) << 1); 1388 point = points[idx]; 1389 1390 vli_set(rx, point->x, ndigits); 1391 vli_set(ry, point->y, ndigits); 1392 vli_clear(z + 1, ndigits - 1); 1393 z[0] = 1; 1394 1395 for (--i; i >= 0; i--) { 1396 ecc_point_double_jacobian(rx, ry, z, curve); 1397 idx = (!!vli_test_bit(u1, i)) | ((!!vli_test_bit(u2, i)) << 1); 1398 point = points[idx]; 1399 if (point) { 1400 u64 tx[ECC_MAX_DIGITS]; 1401 u64 ty[ECC_MAX_DIGITS]; 1402 u64 tz[ECC_MAX_DIGITS]; 1403 1404 vli_set(tx, point->x, ndigits); 1405 vli_set(ty, point->y, ndigits); 1406 apply_z(tx, ty, z, curve); 1407 vli_mod_sub(tz, rx, tx, curve->p, ndigits); 1408 xycz_add(tx, ty, rx, ry, curve); 1409 vli_mod_mult_fast(z, z, tz, curve); 1410 } 1411 } 1412 vli_mod_inv(z, z, curve->p, ndigits); 1413 apply_z(rx, ry, z, curve); 1414 } 1415 EXPORT_SYMBOL(ecc_point_mult_shamir); 1416 1417 static int __ecc_is_key_valid(const struct ecc_curve *curve, 1418 const u64 *private_key, unsigned int ndigits) 1419 { 1420 u64 one[ECC_MAX_DIGITS] = { 1, }; 1421 u64 res[ECC_MAX_DIGITS]; 1422 1423 if (!private_key) 1424 return -EINVAL; 1425 1426 if (curve->g.ndigits != ndigits) 1427 return 
-EINVAL; 1428 1429 /* Make sure the private key is in the range [2, n-3]. */ 1430 if (vli_cmp(one, private_key, ndigits) != -1) 1431 return -EINVAL; 1432 vli_sub(res, curve->n, one, ndigits); 1433 vli_sub(res, res, one, ndigits); 1434 if (vli_cmp(res, private_key, ndigits) != 1) 1435 return -EINVAL; 1436 1437 return 0; 1438 } 1439 1440 int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits, 1441 const u64 *private_key, unsigned int private_key_len) 1442 { 1443 int nbytes; 1444 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1445 1446 nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1447 1448 if (private_key_len != nbytes) 1449 return -EINVAL; 1450 1451 return __ecc_is_key_valid(curve, private_key, ndigits); 1452 } 1453 EXPORT_SYMBOL(ecc_is_key_valid); 1454 1455 /* 1456 * ECC private keys are generated using the method of extra random bits, 1457 * equivalent to that described in FIPS 186-4, Appendix B.4.1. 1458 * 1459 * d = (c mod(n–1)) + 1 where c is a string of random bits, 64 bits longer 1460 * than requested 1461 * 0 <= c mod(n-1) <= n-2 and implies that 1462 * 1 <= d <= n-1 1463 * 1464 * This method generates a private key uniformly distributed in the range 1465 * [1, n-1]. 1466 */ 1467 int ecc_gen_privkey(unsigned int curve_id, unsigned int ndigits, u64 *privkey) 1468 { 1469 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1470 u64 priv[ECC_MAX_DIGITS]; 1471 unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1472 unsigned int nbits = vli_num_bits(curve->n, ndigits); 1473 int err; 1474 1475 /* Check that N is included in Table 1 of FIPS 186-4, section 6.1.1 */ 1476 if (nbits < 160 || ndigits > ARRAY_SIZE(priv)) 1477 return -EINVAL; 1478 1479 /* 1480 * FIPS 186-4 recommends that the private key should be obtained from a 1481 * RBG with a security strength equal to or greater than the security 1482 * strength associated with N. 
1483 * 1484 * The maximum security strength identified by NIST SP800-57pt1r4 for 1485 * ECC is 256 (N >= 512). 1486 * 1487 * This condition is met by the default RNG because it selects a favored 1488 * DRBG with a security strength of 256. 1489 */ 1490 if (crypto_get_default_rng()) 1491 return -EFAULT; 1492 1493 err = crypto_rng_get_bytes(crypto_default_rng, (u8 *)priv, nbytes); 1494 crypto_put_default_rng(); 1495 if (err) 1496 return err; 1497 1498 /* Make sure the private key is in the valid range. */ 1499 if (__ecc_is_key_valid(curve, priv, ndigits)) 1500 return -EINVAL; 1501 1502 ecc_swap_digits(priv, privkey, ndigits); 1503 1504 return 0; 1505 } 1506 EXPORT_SYMBOL(ecc_gen_privkey); 1507 1508 int ecc_make_pub_key(unsigned int curve_id, unsigned int ndigits, 1509 const u64 *private_key, u64 *public_key) 1510 { 1511 int ret = 0; 1512 struct ecc_point *pk; 1513 u64 priv[ECC_MAX_DIGITS]; 1514 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1515 1516 if (!private_key || !curve || ndigits > ARRAY_SIZE(priv)) { 1517 ret = -EINVAL; 1518 goto out; 1519 } 1520 1521 ecc_swap_digits(private_key, priv, ndigits); 1522 1523 pk = ecc_alloc_point(ndigits); 1524 if (!pk) { 1525 ret = -ENOMEM; 1526 goto out; 1527 } 1528 1529 ecc_point_mult(pk, &curve->g, priv, NULL, curve, ndigits); 1530 1531 /* SP800-56A rev 3 5.6.2.1.3 key check */ 1532 if (ecc_is_pubkey_valid_full(curve, pk)) { 1533 ret = -EAGAIN; 1534 goto err_free_point; 1535 } 1536 1537 ecc_swap_digits(pk->x, public_key, ndigits); 1538 ecc_swap_digits(pk->y, &public_key[ndigits], ndigits); 1539 1540 err_free_point: 1541 ecc_free_point(pk); 1542 out: 1543 return ret; 1544 } 1545 EXPORT_SYMBOL(ecc_make_pub_key); 1546 1547 /* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */ 1548 int ecc_is_pubkey_valid_partial(const struct ecc_curve *curve, 1549 struct ecc_point *pk) 1550 { 1551 u64 yy[ECC_MAX_DIGITS], xxx[ECC_MAX_DIGITS], w[ECC_MAX_DIGITS]; 1552 1553 if (WARN_ON(pk->ndigits != 
curve->g.ndigits)) 1554 return -EINVAL; 1555 1556 /* Check 1: Verify key is not the zero point. */ 1557 if (ecc_point_is_zero(pk)) 1558 return -EINVAL; 1559 1560 /* Check 2: Verify key is in the range [1, p-1]. */ 1561 if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1) 1562 return -EINVAL; 1563 if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1) 1564 return -EINVAL; 1565 1566 /* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */ 1567 vli_mod_square_fast(yy, pk->y, curve); /* y^2 */ 1568 vli_mod_square_fast(xxx, pk->x, curve); /* x^2 */ 1569 vli_mod_mult_fast(xxx, xxx, pk->x, curve); /* x^3 */ 1570 vli_mod_mult_fast(w, curve->a, pk->x, curve); /* a·x */ 1571 vli_mod_add(w, w, curve->b, curve->p, pk->ndigits); /* a·x + b */ 1572 vli_mod_add(w, w, xxx, curve->p, pk->ndigits); /* x^3 + a·x + b */ 1573 if (vli_cmp(yy, w, pk->ndigits) != 0) /* Equation */ 1574 return -EINVAL; 1575 1576 return 0; 1577 } 1578 EXPORT_SYMBOL(ecc_is_pubkey_valid_partial); 1579 1580 /* SP800-56A section 5.6.2.3.3 full verification */ 1581 int ecc_is_pubkey_valid_full(const struct ecc_curve *curve, 1582 struct ecc_point *pk) 1583 { 1584 struct ecc_point *nQ; 1585 1586 /* Checks 1 through 3 */ 1587 int ret = ecc_is_pubkey_valid_partial(curve, pk); 1588 1589 if (ret) 1590 return ret; 1591 1592 /* Check 4: Verify that nQ is the zero point. 
*/ 1593 nQ = ecc_alloc_point(pk->ndigits); 1594 if (!nQ) 1595 return -ENOMEM; 1596 1597 ecc_point_mult(nQ, pk, curve->n, NULL, curve, pk->ndigits); 1598 if (!ecc_point_is_zero(nQ)) 1599 ret = -EINVAL; 1600 1601 ecc_free_point(nQ); 1602 1603 return ret; 1604 } 1605 EXPORT_SYMBOL(ecc_is_pubkey_valid_full); 1606 1607 int crypto_ecdh_shared_secret(unsigned int curve_id, unsigned int ndigits, 1608 const u64 *private_key, const u64 *public_key, 1609 u64 *secret) 1610 { 1611 int ret = 0; 1612 struct ecc_point *product, *pk; 1613 u64 priv[ECC_MAX_DIGITS]; 1614 u64 rand_z[ECC_MAX_DIGITS]; 1615 unsigned int nbytes; 1616 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1617 1618 if (!private_key || !public_key || !curve || 1619 ndigits > ARRAY_SIZE(priv) || ndigits > ARRAY_SIZE(rand_z)) { 1620 ret = -EINVAL; 1621 goto out; 1622 } 1623 1624 nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1625 1626 get_random_bytes(rand_z, nbytes); 1627 1628 pk = ecc_alloc_point(ndigits); 1629 if (!pk) { 1630 ret = -ENOMEM; 1631 goto out; 1632 } 1633 1634 ecc_swap_digits(public_key, pk->x, ndigits); 1635 ecc_swap_digits(&public_key[ndigits], pk->y, ndigits); 1636 ret = ecc_is_pubkey_valid_partial(curve, pk); 1637 if (ret) 1638 goto err_alloc_product; 1639 1640 ecc_swap_digits(private_key, priv, ndigits); 1641 1642 product = ecc_alloc_point(ndigits); 1643 if (!product) { 1644 ret = -ENOMEM; 1645 goto err_alloc_product; 1646 } 1647 1648 ecc_point_mult(product, pk, priv, rand_z, curve, ndigits); 1649 1650 if (ecc_point_is_zero(product)) { 1651 ret = -EFAULT; 1652 goto err_validity; 1653 } 1654 1655 ecc_swap_digits(product->x, secret, ndigits); 1656 1657 err_validity: 1658 memzero_explicit(priv, sizeof(priv)); 1659 memzero_explicit(rand_z, sizeof(rand_z)); 1660 ecc_free_point(product); 1661 err_alloc_product: 1662 ecc_free_point(pk); 1663 out: 1664 return ret; 1665 } 1666 EXPORT_SYMBOL(crypto_ecdh_shared_secret); 1667 1668 MODULE_LICENSE("Dual BSD/GPL"); 1669