1 /* 2 * Copyright (c) 2013, Kenneth MacKay 3 * All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are 7 * met: 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 18 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 19 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 20 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 */ 26 27 #include <linux/random.h> 28 #include <linux/slab.h> 29 #include <linux/swab.h> 30 #include <linux/fips.h> 31 #include <crypto/ecdh.h> 32 #include <crypto/rng.h> 33 34 #include "ecc.h" 35 #include "ecc_curve_defs.h" 36 37 typedef struct { 38 u64 m_low; 39 u64 m_high; 40 } uint128_t; 41 42 static inline const struct ecc_curve *ecc_get_curve(unsigned int curve_id) 43 { 44 switch (curve_id) { 45 /* In FIPS mode only allow P256 and higher */ 46 case ECC_CURVE_NIST_P192: 47 return fips_enabled ? NULL : &nist_p192; 48 case ECC_CURVE_NIST_P256: 49 return &nist_p256; 50 default: 51 return NULL; 52 } 53 } 54 55 static u64 *ecc_alloc_digits_space(unsigned int ndigits) 56 { 57 size_t len = ndigits * sizeof(u64); 58 59 if (!len) 60 return NULL; 61 62 return kmalloc(len, GFP_KERNEL); 63 } 64 65 static void ecc_free_digits_space(u64 *space) 66 { 67 kzfree(space); 68 } 69 70 static struct ecc_point *ecc_alloc_point(unsigned int ndigits) 71 { 72 struct ecc_point *p = kmalloc(sizeof(*p), GFP_KERNEL); 73 74 if (!p) 75 return NULL; 76 77 p->x = ecc_alloc_digits_space(ndigits); 78 if (!p->x) 79 goto err_alloc_x; 80 81 p->y = ecc_alloc_digits_space(ndigits); 82 if (!p->y) 83 goto err_alloc_y; 84 85 p->ndigits = ndigits; 86 87 return p; 88 89 err_alloc_y: 90 ecc_free_digits_space(p->x); 91 err_alloc_x: 92 kfree(p); 93 return NULL; 94 } 95 96 static void ecc_free_point(struct ecc_point *p) 97 { 98 if (!p) 99 return; 100 101 kzfree(p->x); 102 kzfree(p->y); 103 kzfree(p); 104 } 105 106 static void vli_clear(u64 *vli, unsigned int ndigits) 107 { 108 int i; 109 110 for (i = 0; i < ndigits; i++) 111 vli[i] = 0; 112 } 113 114 /* Returns true if vli == 0, false otherwise. */ 115 static bool vli_is_zero(const u64 *vli, unsigned int ndigits) 116 { 117 int i; 118 119 for (i = 0; i < ndigits; i++) { 120 if (vli[i]) 121 return false; 122 } 123 124 return true; 125 } 126 127 /* Returns nonzero if bit bit of vli is set. */ 128 static u64 vli_test_bit(const u64 *vli, unsigned int bit) 129 { 130 return (vli[bit / 64] & ((u64)1 << (bit % 64))); 131 } 132 133 /* Counts the number of 64-bit "digits" in vli. */ 134 static unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits) 135 { 136 int i; 137 138 /* Search from the end until we find a non-zero digit. 139 * We do it in reverse because we expect that most digits will 140 * be nonzero. 141 */ 142 for (i = ndigits - 1; i >= 0 && vli[i] == 0; i--); 143 144 return (i + 1); 145 } 146 147 /* Counts the number of bits required for vli. */ 148 static unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits) 149 { 150 unsigned int i, num_digits; 151 u64 digit; 152 153 num_digits = vli_num_digits(vli, ndigits); 154 if (num_digits == 0) 155 return 0; 156 157 digit = vli[num_digits - 1]; 158 for (i = 0; digit; i++) 159 digit >>= 1; 160 161 return ((num_digits - 1) * 64 + i); 162 } 163 164 /* Sets dest = src. */ 165 static void vli_set(u64 *dest, const u64 *src, unsigned int ndigits) 166 { 167 int i; 168 169 for (i = 0; i < ndigits; i++) 170 dest[i] = src[i]; 171 } 172 173 /* Returns sign of left - right. */ 174 static int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits) 175 { 176 int i; 177 178 for (i = ndigits - 1; i >= 0; i--) { 179 if (left[i] > right[i]) 180 return 1; 181 else if (left[i] < right[i]) 182 return -1; 183 } 184 185 return 0; 186 } 187 188 /* Computes result = in << c, returning carry. Can modify in place 189 * (if result == in). 0 < shift < 64. 190 */ 191 static u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift, 192 unsigned int ndigits) 193 { 194 u64 carry = 0; 195 int i; 196 197 for (i = 0; i < ndigits; i++) { 198 u64 temp = in[i]; 199 200 result[i] = (temp << shift) | carry; 201 carry = temp >> (64 - shift); 202 } 203 204 return carry; 205 } 206 207 /* Computes vli = vli >> 1. */ 208 static void vli_rshift1(u64 *vli, unsigned int ndigits) 209 { 210 u64 *end = vli; 211 u64 carry = 0; 212 213 vli += ndigits; 214 215 while (vli-- > end) { 216 u64 temp = *vli; 217 *vli = (temp >> 1) | carry; 218 carry = temp << 63; 219 } 220 } 221 222 /* Computes result = left + right, returning carry. Can modify in place. */ 223 static u64 vli_add(u64 *result, const u64 *left, const u64 *right, 224 unsigned int ndigits) 225 { 226 u64 carry = 0; 227 int i; 228 229 for (i = 0; i < ndigits; i++) { 230 u64 sum; 231 232 sum = left[i] + right[i] + carry; 233 if (sum != left[i]) 234 carry = (sum < left[i]); 235 236 result[i] = sum; 237 } 238 239 return carry; 240 } 241 242 /* Computes result = left - right, returning borrow. Can modify in place. */ 243 static u64 vli_sub(u64 *result, const u64 *left, const u64 *right, 244 unsigned int ndigits) 245 { 246 u64 borrow = 0; 247 int i; 248 249 for (i = 0; i < ndigits; i++) { 250 u64 diff; 251 252 diff = left[i] - right[i] - borrow; 253 if (diff != left[i]) 254 borrow = (diff > left[i]); 255 256 result[i] = diff; 257 } 258 259 return borrow; 260 } 261 262 static uint128_t mul_64_64(u64 left, u64 right) 263 { 264 u64 a0 = left & 0xffffffffull; 265 u64 a1 = left >> 32; 266 u64 b0 = right & 0xffffffffull; 267 u64 b1 = right >> 32; 268 u64 m0 = a0 * b0; 269 u64 m1 = a0 * b1; 270 u64 m2 = a1 * b0; 271 u64 m3 = a1 * b1; 272 uint128_t result; 273 274 m2 += (m0 >> 32); 275 m2 += m1; 276 277 /* Overflow */ 278 if (m2 < m1) 279 m3 += 0x100000000ull; 280 281 result.m_low = (m0 & 0xffffffffull) | (m2 << 32); 282 result.m_high = m3 + (m2 >> 32); 283 284 return result; 285 } 286 287 static uint128_t add_128_128(uint128_t a, uint128_t b) 288 { 289 uint128_t result; 290 291 result.m_low = a.m_low + b.m_low; 292 result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low); 293 294 return result; 295 } 296 297 static void vli_mult(u64 *result, const u64 *left, const u64 *right, 298 unsigned int ndigits) 299 { 300 uint128_t r01 = { 0, 0 }; 301 u64 r2 = 0; 302 unsigned int i, k; 303 304 /* Compute each digit of result in sequence, maintaining the 305 * carries. 306 */ 307 for (k = 0; k < ndigits * 2 - 1; k++) { 308 unsigned int min; 309 310 if (k < ndigits) 311 min = 0; 312 else 313 min = (k + 1) - ndigits; 314 315 for (i = min; i <= k && i < ndigits; i++) { 316 uint128_t product; 317 318 product = mul_64_64(left[i], right[k - i]); 319 320 r01 = add_128_128(r01, product); 321 r2 += (r01.m_high < product.m_high); 322 } 323 324 result[k] = r01.m_low; 325 r01.m_low = r01.m_high; 326 r01.m_high = r2; 327 r2 = 0; 328 } 329 330 result[ndigits * 2 - 1] = r01.m_low; 331 } 332 333 static void vli_square(u64 *result, const u64 *left, unsigned int ndigits) 334 { 335 uint128_t r01 = { 0, 0 }; 336 u64 r2 = 0; 337 int i, k; 338 339 for (k = 0; k < ndigits * 2 - 1; k++) { 340 unsigned int min; 341 342 if (k < ndigits) 343 min = 0; 344 else 345 min = (k + 1) - ndigits; 346 347 for (i = min; i <= k && i <= k - i; i++) { 348 uint128_t product; 349 350 product = mul_64_64(left[i], left[k - i]); 351 352 if (i < k - i) { 353 r2 += product.m_high >> 63; 354 product.m_high = (product.m_high << 1) | 355 (product.m_low >> 63); 356 product.m_low <<= 1; 357 } 358 359 r01 = add_128_128(r01, product); 360 r2 += (r01.m_high < product.m_high); 361 } 362 363 result[k] = r01.m_low; 364 r01.m_low = r01.m_high; 365 r01.m_high = r2; 366 r2 = 0; 367 } 368 369 result[ndigits * 2 - 1] = r01.m_low; 370 } 371 372 /* Computes result = (left + right) % mod. 373 * Assumes that left < mod and right < mod, result != mod. 374 */ 375 static void vli_mod_add(u64 *result, const u64 *left, const u64 *right, 376 const u64 *mod, unsigned int ndigits) 377 { 378 u64 carry; 379 380 carry = vli_add(result, left, right, ndigits); 381 382 /* result > mod (result = mod + remainder), so subtract mod to 383 * get remainder. 384 */ 385 if (carry || vli_cmp(result, mod, ndigits) >= 0) 386 vli_sub(result, result, mod, ndigits); 387 } 388 389 /* Computes result = (left - right) % mod. 390 * Assumes that left < mod and right < mod, result != mod. 391 */ 392 static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right, 393 const u64 *mod, unsigned int ndigits) 394 { 395 u64 borrow = vli_sub(result, left, right, ndigits); 396 397 /* In this case, p_result == -diff == (max int) - diff. 398 * Since -x % d == d - x, we can get the correct result from 399 * result + mod (with overflow). 400 */ 401 if (borrow) 402 vli_add(result, result, mod, ndigits); 403 } 404 405 /* Computes p_result = p_product % curve_p. 406 * See algorithm 5 and 6 from 407 * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf 408 */ 409 static void vli_mmod_fast_192(u64 *result, const u64 *product, 410 const u64 *curve_prime, u64 *tmp) 411 { 412 const unsigned int ndigits = 3; 413 int carry; 414 415 vli_set(result, product, ndigits); 416 417 vli_set(tmp, &product[3], ndigits); 418 carry = vli_add(result, result, tmp, ndigits); 419 420 tmp[0] = 0; 421 tmp[1] = product[3]; 422 tmp[2] = product[4]; 423 carry += vli_add(result, result, tmp, ndigits); 424 425 tmp[0] = tmp[1] = product[5]; 426 tmp[2] = 0; 427 carry += vli_add(result, result, tmp, ndigits); 428 429 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 430 carry -= vli_sub(result, result, curve_prime, ndigits); 431 } 432 433 /* Computes result = product % curve_prime 434 * from http://www.nsa.gov/ia/_files/nist-routines.pdf 435 */ 436 static void vli_mmod_fast_256(u64 *result, const u64 *product, 437 const u64 *curve_prime, u64 *tmp) 438 { 439 int carry; 440 const unsigned int ndigits = 4; 441 442 /* t */ 443 vli_set(result, product, ndigits); 444 445 /* s1 */ 446 tmp[0] = 0; 447 tmp[1] = product[5] & 0xffffffff00000000ull; 448 tmp[2] = product[6]; 449 tmp[3] = product[7]; 450 carry = vli_lshift(tmp, tmp, 1, ndigits); 451 carry += vli_add(result, result, tmp, ndigits); 452 453 /* s2 */ 454 tmp[1] = product[6] << 32; 455 tmp[2] = (product[6] >> 32) | (product[7] << 32); 456 tmp[3] = product[7] >> 32; 457 carry += vli_lshift(tmp, tmp, 1, ndigits); 458 carry += vli_add(result, result, tmp, ndigits); 459 460 /* s3 */ 461 tmp[0] = product[4]; 462 tmp[1] = product[5] & 0xffffffff; 463 tmp[2] = 0; 464 tmp[3] = product[7]; 465 carry += vli_add(result, result, tmp, ndigits); 466 467 /* s4 */ 468 tmp[0] = (product[4] >> 32) | (product[5] << 32); 469 tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull); 470 tmp[2] = product[7]; 471 tmp[3] = (product[6] >> 32) | (product[4] << 32); 472 carry += vli_add(result, result, tmp, ndigits); 473 474 /* d1 */ 475 tmp[0] = (product[5] >> 32) | (product[6] << 32); 476 tmp[1] = (product[6] >> 32); 477 tmp[2] = 0; 478 tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32); 479 carry -= vli_sub(result, result, tmp, ndigits); 480 481 /* d2 */ 482 tmp[0] = product[6]; 483 tmp[1] = product[7]; 484 tmp[2] = 0; 485 tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull); 486 carry -= vli_sub(result, result, tmp, ndigits); 487 488 /* d3 */ 489 tmp[0] = (product[6] >> 32) | (product[7] << 32); 490 tmp[1] = (product[7] >> 32) | (product[4] << 32); 491 tmp[2] = (product[4] >> 32) | (product[5] << 32); 492 tmp[3] = (product[6] << 32); 493 carry -= vli_sub(result, result, tmp, ndigits); 494 495 /* d4 */ 496 tmp[0] = product[7]; 497 tmp[1] = product[4] & 0xffffffff00000000ull; 498 tmp[2] = product[5]; 499 tmp[3] = product[6] & 0xffffffff00000000ull; 500 carry -= vli_sub(result, result, tmp, ndigits); 501 502 if (carry < 0) { 503 do { 504 carry += vli_add(result, result, curve_prime, ndigits); 505 } while (carry < 0); 506 } else { 507 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 508 carry -= vli_sub(result, result, curve_prime, ndigits); 509 } 510 } 511 512 /* Computes result = product % curve_prime 513 * from http://www.nsa.gov/ia/_files/nist-routines.pdf 514 */ 515 static bool vli_mmod_fast(u64 *result, u64 *product, 516 const u64 *curve_prime, unsigned int ndigits) 517 { 518 u64 tmp[2 * ndigits]; 519 520 switch (ndigits) { 521 case 3: 522 vli_mmod_fast_192(result, product, curve_prime, tmp); 523 break; 524 case 4: 525 vli_mmod_fast_256(result, product, curve_prime, tmp); 526 break; 527 default: 528 pr_err("unsupports digits size!\n"); 529 return false; 530 } 531 532 return true; 533 } 534 535 /* Computes result = (left * right) % curve_prime. */ 536 static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right, 537 const u64 *curve_prime, unsigned int ndigits) 538 { 539 u64 product[2 * ndigits]; 540 541 vli_mult(product, left, right, ndigits); 542 vli_mmod_fast(result, product, curve_prime, ndigits); 543 } 544 545 /* Computes result = left^2 % curve_prime. */ 546 static void vli_mod_square_fast(u64 *result, const u64 *left, 547 const u64 *curve_prime, unsigned int ndigits) 548 { 549 u64 product[2 * ndigits]; 550 551 vli_square(product, left, ndigits); 552 vli_mmod_fast(result, product, curve_prime, ndigits); 553 } 554 555 #define EVEN(vli) (!(vli[0] & 1)) 556 /* Computes result = (1 / p_input) % mod. All VLIs are the same size. 557 * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide" 558 * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf 559 */ 560 static void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod, 561 unsigned int ndigits) 562 { 563 u64 a[ndigits], b[ndigits]; 564 u64 u[ndigits], v[ndigits]; 565 u64 carry; 566 int cmp_result; 567 568 if (vli_is_zero(input, ndigits)) { 569 vli_clear(result, ndigits); 570 return; 571 } 572 573 vli_set(a, input, ndigits); 574 vli_set(b, mod, ndigits); 575 vli_clear(u, ndigits); 576 u[0] = 1; 577 vli_clear(v, ndigits); 578 579 while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) { 580 carry = 0; 581 582 if (EVEN(a)) { 583 vli_rshift1(a, ndigits); 584 585 if (!EVEN(u)) 586 carry = vli_add(u, u, mod, ndigits); 587 588 vli_rshift1(u, ndigits); 589 if (carry) 590 u[ndigits - 1] |= 0x8000000000000000ull; 591 } else if (EVEN(b)) { 592 vli_rshift1(b, ndigits); 593 594 if (!EVEN(v)) 595 carry = vli_add(v, v, mod, ndigits); 596 597 vli_rshift1(v, ndigits); 598 if (carry) 599 v[ndigits - 1] |= 0x8000000000000000ull; 600 } else if (cmp_result > 0) { 601 vli_sub(a, a, b, ndigits); 602 vli_rshift1(a, ndigits); 603 604 if (vli_cmp(u, v, ndigits) < 0) 605 vli_add(u, u, mod, ndigits); 606 607 vli_sub(u, u, v, ndigits); 608 if (!EVEN(u)) 609 carry = vli_add(u, u, mod, ndigits); 610 611 vli_rshift1(u, ndigits); 612 if (carry) 613 u[ndigits - 1] |= 0x8000000000000000ull; 614 } else { 615 vli_sub(b, b, a, ndigits); 616 vli_rshift1(b, ndigits); 617 618 if (vli_cmp(v, u, ndigits) < 0) 619 vli_add(v, v, mod, ndigits); 620 621 vli_sub(v, v, u, ndigits); 622 if (!EVEN(v)) 623 carry = vli_add(v, v, mod, ndigits); 624 625 vli_rshift1(v, ndigits); 626 if (carry) 627 v[ndigits - 1] |= 0x8000000000000000ull; 628 } 629 } 630 631 vli_set(result, u, ndigits); 632 } 633 634 /* ------ Point operations ------ */ 635 636 /* Returns true if p_point is the point at infinity, false otherwise. */ 637 static bool ecc_point_is_zero(const struct ecc_point *point) 638 { 639 return (vli_is_zero(point->x, point->ndigits) && 640 vli_is_zero(point->y, point->ndigits)); 641 } 642 643 /* Point multiplication algorithm using Montgomery's ladder with co-Z 644 * coordinates. From http://eprint.iacr.org/2011/338.pdf 645 */ 646 647 /* Double in place */ 648 static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1, 649 u64 *curve_prime, unsigned int ndigits) 650 { 651 /* t1 = x, t2 = y, t3 = z */ 652 u64 t4[ndigits]; 653 u64 t5[ndigits]; 654 655 if (vli_is_zero(z1, ndigits)) 656 return; 657 658 /* t4 = y1^2 */ 659 vli_mod_square_fast(t4, y1, curve_prime, ndigits); 660 /* t5 = x1*y1^2 = A */ 661 vli_mod_mult_fast(t5, x1, t4, curve_prime, ndigits); 662 /* t4 = y1^4 */ 663 vli_mod_square_fast(t4, t4, curve_prime, ndigits); 664 /* t2 = y1*z1 = z3 */ 665 vli_mod_mult_fast(y1, y1, z1, curve_prime, ndigits); 666 /* t3 = z1^2 */ 667 vli_mod_square_fast(z1, z1, curve_prime, ndigits); 668 669 /* t1 = x1 + z1^2 */ 670 vli_mod_add(x1, x1, z1, curve_prime, ndigits); 671 /* t3 = 2*z1^2 */ 672 vli_mod_add(z1, z1, z1, curve_prime, ndigits); 673 /* t3 = x1 - z1^2 */ 674 vli_mod_sub(z1, x1, z1, curve_prime, ndigits); 675 /* t1 = x1^2 - z1^4 */ 676 vli_mod_mult_fast(x1, x1, z1, curve_prime, ndigits); 677 678 /* t3 = 2*(x1^2 - z1^4) */ 679 vli_mod_add(z1, x1, x1, curve_prime, ndigits); 680 /* t1 = 3*(x1^2 - z1^4) */ 681 vli_mod_add(x1, x1, z1, curve_prime, ndigits); 682 if (vli_test_bit(x1, 0)) { 683 u64 carry = vli_add(x1, x1, curve_prime, ndigits); 684 685 vli_rshift1(x1, ndigits); 686 x1[ndigits - 1] |= carry << 63; 687 } else { 688 vli_rshift1(x1, ndigits); 689 } 690 /* t1 = 3/2*(x1^2 - z1^4) = B */ 691 692 /* t3 = B^2 */ 693 vli_mod_square_fast(z1, x1, curve_prime, ndigits); 694 /* t3 = B^2 - A */ 695 vli_mod_sub(z1, z1, t5, curve_prime, ndigits); 696 /* t3 = B^2 - 2A = x3 */ 697 vli_mod_sub(z1, z1, t5, curve_prime, ndigits); 698 /* t5 = A - x3 */ 699 vli_mod_sub(t5, t5, z1, curve_prime, ndigits); 700 /* t1 = B * (A - x3) */ 701 vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits); 702 /* t4 = B * (A - x3) - y1^4 = y3 */ 703 vli_mod_sub(t4, x1, t4, curve_prime, ndigits); 704 705 vli_set(x1, z1, ndigits); 706 vli_set(z1, y1, ndigits); 707 vli_set(y1, t4, ndigits); 708 } 709 710 /* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */ 711 static void apply_z(u64 *x1, u64 *y1, u64 *z, u64 *curve_prime, 712 unsigned int ndigits) 713 { 714 u64 t1[ndigits]; 715 716 vli_mod_square_fast(t1, z, curve_prime, ndigits); /* z^2 */ 717 vli_mod_mult_fast(x1, x1, t1, curve_prime, ndigits); /* x1 * z^2 */ 718 vli_mod_mult_fast(t1, t1, z, curve_prime, ndigits); /* z^3 */ 719 vli_mod_mult_fast(y1, y1, t1, curve_prime, ndigits); /* y1 * z^3 */ 720 } 721 722 /* P = (x1, y1) => 2P, (x2, y2) => P' */ 723 static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2, 724 u64 *p_initial_z, u64 *curve_prime, 725 unsigned int ndigits) 726 { 727 u64 z[ndigits]; 728 729 vli_set(x2, x1, ndigits); 730 vli_set(y2, y1, ndigits); 731 732 vli_clear(z, ndigits); 733 z[0] = 1; 734 735 if (p_initial_z) 736 vli_set(z, p_initial_z, ndigits); 737 738 apply_z(x1, y1, z, curve_prime, ndigits); 739 740 ecc_point_double_jacobian(x1, y1, z, curve_prime, ndigits); 741 742 apply_z(x2, y2, z, curve_prime, ndigits); 743 } 744 745 /* Input P = (x1, y1, Z), Q = (x2, y2, Z) 746 * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3) 747 * or P => P', Q => P + Q 748 */ 749 static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime, 750 unsigned int ndigits) 751 { 752 /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */ 753 u64 t5[ndigits]; 754 755 /* t5 = x2 - x1 */ 756 vli_mod_sub(t5, x2, x1, curve_prime, ndigits); 757 /* t5 = (x2 - x1)^2 = A */ 758 vli_mod_square_fast(t5, t5, curve_prime, ndigits); 759 /* t1 = x1*A = B */ 760 vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits); 761 /* t3 = x2*A = C */ 762 vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits); 763 /* t4 = y2 - y1 */ 764 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 765 /* t5 = (y2 - y1)^2 = D */ 766 vli_mod_square_fast(t5, y2, curve_prime, ndigits); 767 768 /* t5 = D - B */ 769 vli_mod_sub(t5, t5, x1, curve_prime, ndigits); 770 /* t5 = D - B - C = x3 */ 771 vli_mod_sub(t5, t5, x2, curve_prime, ndigits); 772 /* t3 = C - B */ 773 vli_mod_sub(x2, x2, x1, curve_prime, ndigits); 774 /* t2 = y1*(C - B) */ 775 vli_mod_mult_fast(y1, y1, x2, curve_prime, ndigits); 776 /* t3 = B - x3 */ 777 vli_mod_sub(x2, x1, t5, curve_prime, ndigits); 778 /* t4 = (y2 - y1)*(B - x3) */ 779 vli_mod_mult_fast(y2, y2, x2, curve_prime, ndigits); 780 /* t4 = y3 */ 781 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 782 783 vli_set(x2, t5, ndigits); 784 } 785 786 /* Input P = (x1, y1, Z), Q = (x2, y2, Z) 787 * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3) 788 * or P => P - Q, Q => P + Q 789 */ 790 static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime, 791 unsigned int ndigits) 792 { 793 /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */ 794 u64 t5[ndigits]; 795 u64 t6[ndigits]; 796 u64 t7[ndigits]; 797 798 /* t5 = x2 - x1 */ 799 vli_mod_sub(t5, x2, x1, curve_prime, ndigits); 800 /* t5 = (x2 - x1)^2 = A */ 801 vli_mod_square_fast(t5, t5, curve_prime, ndigits); 802 /* t1 = x1*A = B */ 803 vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits); 804 /* t3 = x2*A = C */ 805 vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits); 806 /* t4 = y2 + y1 */ 807 vli_mod_add(t5, y2, y1, curve_prime, ndigits); 808 /* t4 = y2 - y1 */ 809 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 810 811 /* t6 = C - B */ 812 vli_mod_sub(t6, x2, x1, curve_prime, ndigits); 813 /* t2 = y1 * (C - B) */ 814 vli_mod_mult_fast(y1, y1, t6, curve_prime, ndigits); 815 /* t6 = B + C */ 816 vli_mod_add(t6, x1, x2, curve_prime, ndigits); 817 /* t3 = (y2 - y1)^2 */ 818 vli_mod_square_fast(x2, y2, curve_prime, ndigits); 819 /* t3 = x3 */ 820 vli_mod_sub(x2, x2, t6, curve_prime, ndigits); 821 822 /* t7 = B - x3 */ 823 vli_mod_sub(t7, x1, x2, curve_prime, ndigits); 824 /* t4 = (y2 - y1)*(B - x3) */ 825 vli_mod_mult_fast(y2, y2, t7, curve_prime, ndigits); 826 /* t4 = y3 */ 827 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 828 829 /* t7 = (y2 + y1)^2 = F */ 830 vli_mod_square_fast(t7, t5, curve_prime, ndigits); 831 /* t7 = x3' */ 832 vli_mod_sub(t7, t7, t6, curve_prime, ndigits); 833 /* t6 = x3' - B */ 834 vli_mod_sub(t6, t7, x1, curve_prime, ndigits); 835 /* t6 = (y2 + y1)*(x3' - B) */ 836 vli_mod_mult_fast(t6, t6, t5, curve_prime, ndigits); 837 /* t2 = y3' */ 838 vli_mod_sub(y1, t6, y1, curve_prime, ndigits); 839 840 vli_set(x1, t7, ndigits); 841 } 842 843 static void ecc_point_mult(struct ecc_point *result, 844 const struct ecc_point *point, const u64 *scalar, 845 u64 *initial_z, u64 *curve_prime, 846 unsigned int ndigits) 847 { 848 /* R0 and R1 */ 849 u64 rx[2][ndigits]; 850 u64 ry[2][ndigits]; 851 u64 z[ndigits]; 852 int i, nb; 853 int num_bits = vli_num_bits(scalar, ndigits); 854 855 vli_set(rx[1], point->x, ndigits); 856 vli_set(ry[1], point->y, ndigits); 857 858 xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve_prime, 859 ndigits); 860 861 for (i = num_bits - 2; i > 0; i--) { 862 nb = !vli_test_bit(scalar, i); 863 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime, 864 ndigits); 865 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime, 866 ndigits); 867 } 868 869 nb = !vli_test_bit(scalar, 0); 870 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime, 871 ndigits); 872 873 /* Find final 1/Z value. */ 874 /* X1 - X0 */ 875 vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits); 876 /* Yb * (X1 - X0) */ 877 vli_mod_mult_fast(z, z, ry[1 - nb], curve_prime, ndigits); 878 /* xP * Yb * (X1 - X0) */ 879 vli_mod_mult_fast(z, z, point->x, curve_prime, ndigits); 880 881 /* 1 / (xP * Yb * (X1 - X0)) */ 882 vli_mod_inv(z, z, curve_prime, point->ndigits); 883 884 /* yP / (xP * Yb * (X1 - X0)) */ 885 vli_mod_mult_fast(z, z, point->y, curve_prime, ndigits); 886 /* Xb * yP / (xP * Yb * (X1 - X0)) */ 887 vli_mod_mult_fast(z, z, rx[1 - nb], curve_prime, ndigits); 888 /* End 1/Z calculation */ 889 890 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime, ndigits); 891 892 apply_z(rx[0], ry[0], z, curve_prime, ndigits); 893 894 vli_set(result->x, rx[0], ndigits); 895 vli_set(result->y, ry[0], ndigits); 896 } 897 898 static inline void ecc_swap_digits(const u64 *in, u64 *out, 899 unsigned int ndigits) 900 { 901 int i; 902 903 for (i = 0; i < ndigits; i++) 904 out[i] = __swab64(in[ndigits - 1 - i]); 905 } 906 907 int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits, 908 const u64 *private_key, unsigned int private_key_len) 909 { 910 int nbytes; 911 const struct ecc_curve *curve = ecc_get_curve(curve_id); 912 913 if (!private_key) 914 return -EINVAL; 915 916 nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 917 918 if (private_key_len != nbytes) 919 return -EINVAL; 920 921 if (vli_is_zero(private_key, ndigits)) 922 return -EINVAL; 923 924 /* Make sure the private key is in the range [1, n-1]. */ 925 if (vli_cmp(curve->n, private_key, ndigits) != 1) 926 return -EINVAL; 927 928 return 0; 929 } 930 931 /* 932 * ECC private keys are generated using the method of extra random bits, 933 * equivalent to that described in FIPS 186-4, Appendix B.4.1. 934 * 935 * d = (c mod(n–1)) + 1 where c is a string of random bits, 64 bits longer 936 * than requested 937 * 0 <= c mod(n-1) <= n-2 and implies that 938 * 1 <= d <= n-1 939 * 940 * This method generates a private key uniformly distributed in the range 941 * [1, n-1]. 942 */ 943 int ecc_gen_privkey(unsigned int curve_id, unsigned int ndigits, u64 *privkey) 944 { 945 const struct ecc_curve *curve = ecc_get_curve(curve_id); 946 u64 priv[ndigits]; 947 unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 948 unsigned int nbits = vli_num_bits(curve->n, ndigits); 949 int err; 950 951 /* Check that N is included in Table 1 of FIPS 186-4, section 6.1.1 */ 952 if (nbits < 160) 953 return -EINVAL; 954 955 /* 956 * FIPS 186-4 recommends that the private key should be obtained from a 957 * RBG with a security strength equal to or greater than the security 958 * strength associated with N. 959 * 960 * The maximum security strength identified by NIST SP800-57pt1r4 for 961 * ECC is 256 (N >= 512). 962 * 963 * This condition is met by the default RNG because it selects a favored 964 * DRBG with a security strength of 256. 965 */ 966 if (crypto_get_default_rng()) 967 return -EFAULT; 968 969 err = crypto_rng_get_bytes(crypto_default_rng, (u8 *)priv, nbytes); 970 crypto_put_default_rng(); 971 if (err) 972 return err; 973 974 if (vli_is_zero(priv, ndigits)) 975 return -EINVAL; 976 977 /* Make sure the private key is in the range [1, n-1]. */ 978 if (vli_cmp(curve->n, priv, ndigits) != 1) 979 return -EINVAL; 980 981 ecc_swap_digits(priv, privkey, ndigits); 982 983 return 0; 984 } 985 986 int ecc_make_pub_key(unsigned int curve_id, unsigned int ndigits, 987 const u64 *private_key, u64 *public_key) 988 { 989 int ret = 0; 990 struct ecc_point *pk; 991 u64 priv[ndigits]; 992 const struct ecc_curve *curve = ecc_get_curve(curve_id); 993 994 if (!private_key || !curve) { 995 ret = -EINVAL; 996 goto out; 997 } 998 999 ecc_swap_digits(private_key, priv, ndigits); 1000 1001 pk = ecc_alloc_point(ndigits); 1002 if (!pk) { 1003 ret = -ENOMEM; 1004 goto out; 1005 } 1006 1007 ecc_point_mult(pk, &curve->g, priv, NULL, curve->p, ndigits); 1008 if (ecc_point_is_zero(pk)) { 1009 ret = -EAGAIN; 1010 goto err_free_point; 1011 } 1012 1013 ecc_swap_digits(pk->x, public_key, ndigits); 1014 ecc_swap_digits(pk->y, &public_key[ndigits], ndigits); 1015 1016 err_free_point: 1017 ecc_free_point(pk); 1018 out: 1019 return ret; 1020 } 1021 1022 int crypto_ecdh_shared_secret(unsigned int curve_id, unsigned int ndigits, 1023 const u64 *private_key, const u64 *public_key, 1024 u64 *secret) 1025 { 1026 int ret = 0; 1027 struct ecc_point *product, *pk; 1028 u64 *priv, *rand_z; 1029 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1030 1031 if (!private_key || !public_key || !curve) { 1032 ret = -EINVAL; 1033 goto out; 1034 } 1035 1036 priv = kmalloc_array(ndigits, sizeof(*priv), GFP_KERNEL); 1037 if (!priv) { 1038 ret = -ENOMEM; 1039 goto out; 1040 } 1041 1042 rand_z = kmalloc_array(ndigits, sizeof(*rand_z), GFP_KERNEL); 1043 if (!rand_z) { 1044 ret = -ENOMEM; 1045 goto kfree_out; 1046 } 1047 1048 pk = ecc_alloc_point(ndigits); 1049 if (!pk) { 1050 ret = -ENOMEM; 1051 goto kfree_out; 1052 } 1053 1054 product = ecc_alloc_point(ndigits); 1055 if (!product) { 1056 ret = -ENOMEM; 1057 goto err_alloc_product; 1058 } 1059 1060 get_random_bytes(rand_z, ndigits << ECC_DIGITS_TO_BYTES_SHIFT); 1061 1062 ecc_swap_digits(public_key, pk->x, ndigits); 1063 ecc_swap_digits(&public_key[ndigits], pk->y, ndigits); 1064 ecc_swap_digits(private_key, priv, ndigits); 1065 1066 ecc_point_mult(product, pk, priv, rand_z, curve->p, ndigits); 1067 1068 ecc_swap_digits(product->x, secret, ndigits); 1069 1070 if (ecc_point_is_zero(product)) 1071 ret = -EFAULT; 1072 1073 ecc_free_point(product); 1074 err_alloc_product: 1075 ecc_free_point(pk); 1076 kfree_out: 1077 kzfree(priv); 1078 kzfree(rand_z); 1079 out: 1080 return ret; 1081 } 1082