xref: /openbmc/google-misc/subprojects/libcr51sign/src/libcr51sign_support.c (revision d89f889b1a0c8e0ff7321035eae5ddd4cbc54dba)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <libcr51sign/libcr51sign_support.h>
18 #include <openssl/bio.h>
19 #include <openssl/bn.h>
20 #include <openssl/err.h>
21 #include <openssl/evp.h>
22 #include <openssl/pem.h>
23 #include <openssl/rsa.h>
24 #include <openssl/sha.h>
25 #include <stdint.h>
26 #include <stdio.h>
27 #include <string.h>
28 
29 #ifdef __cplusplus
30 extern "C"
31 {
32 #endif
33 
#if !defined(USER_PRINT)
// Default logger: forwards the printf-style arguments to stderr. The first
// (context) argument is accepted for call-site compatibility and is dropped,
// never evaluated. When USER_PRINT is defined this file does not define
// CPRINTS, so it must be provided elsewhere (presumably by the integrator).
#define CPRINTS(unused_ctx, ...) fprintf(stderr, __VA_ARGS__)
#endif
37 
38 // @func hash_init get ready to compute a hash
39 //
40 // @param[in] ctx - context struct
41 // @param[in] hash_type - type of hash function to use
42 //
43 // @return nonzero on error, zero on success
44 
45 int hash_init(const void* ctx, enum hash_type type)
46 {
47     struct libcr51sign_ctx* context = (struct libcr51sign_ctx*)ctx;
48     struct hash_ctx* hash_context = (struct hash_ctx*)context->priv;
49     hash_context->hash_type = type;
50     if (type == HASH_SHA2_256)
51     { // SHA256_Init returns 1
52         SHA256_Init(&hash_context->sha256_ctx);
53     }
54     else if (type == HASH_SHA2_512)
55     {
56         SHA512_Init(&hash_context->sha512_ctx);
57     }
58     else
59     {
60         return LIBCR51SIGN_ERROR_INVALID_HASH_TYPE;
61     }
62 
63     return LIBCR51SIGN_SUCCESS;
64 }
65 
66 // @func hash_update add data to the hash
67 //
68 // @param[in] ctx - context struct
69 // @param[in] buf - data to add to hash
70 // @param[in] count - number of bytes of data to add
71 //
72 // @return nonzero on error, zero on success
73 
74 int hash_update(void* ctx, const uint8_t* data, size_t size)
75 {
76     if (size == 0)
77         return LIBCR51SIGN_SUCCESS;
78     struct libcr51sign_ctx* context = (struct libcr51sign_ctx*)ctx;
79     struct hash_ctx* hash_context = (struct hash_ctx*)context->priv;
80 
81     if (hash_context->hash_type == HASH_SHA2_256)
82     { // SHA256_Update returns 1
83         SHA256_Update(&hash_context->sha256_ctx, data, size);
84     }
85     else if (hash_context->hash_type == HASH_SHA2_512)
86     {
87         SHA512_Update(&hash_context->sha512_ctx, data, size);
88     }
89     else
90     {
91         return LIBCR51SIGN_ERROR_INVALID_HASH_TYPE;
92     }
93 
94     return LIBCR51SIGN_SUCCESS;
95 }
96 
97 // @func hash_final finish hash calculation
98 //
99 // @param[in] ctx - context struct
100 // @param[out] hash - buffer to write hash to (guaranteed to be big enough)
101 //
102 // @return nonzero on error, zero on success
103 
104 int hash_final(void* ctx, uint8_t* hash)
105 {
106     int rv;
107     struct libcr51sign_ctx* context = (struct libcr51sign_ctx*)ctx;
108     struct hash_ctx* hash_context = (struct hash_ctx*)context->priv;
109 
110     if (hash_context->hash_type == HASH_SHA2_256)
111     {
112         rv = SHA256_Final(hash, &hash_context->sha256_ctx);
113     }
114     else if (hash_context->hash_type == HASH_SHA2_512)
115     {
116         rv = SHA512_Final(hash, &hash_context->sha512_ctx);
117     }
118     else
119     {
120         return LIBCR51SIGN_ERROR_INVALID_HASH_TYPE;
121     }
122 
123     if (rv)
124     {
125         return LIBCR51SIGN_SUCCESS;
126     }
127 
128     return LIBCR51SIGN_ERROR_RUNTIME_FAILURE;
129 }
130 
131 // @func verify check that the signature is valid for given hashed data
132 //
133 // @param[in] ctx - context struct
134 // @param[in] scheme - type of signature, hash, etc.
135 // @param[in] sig - signature blob
136 // @param[in] sig_len - length of signature in bytes
137 // @param[in] data - pre-hashed data to verify
138 // @param[in] data_len - length of hashed data in bytes
139 //
140 // verify_signature expects RSA public key file path in ctx->key_ring
141 // @return nonzero on error, zero on success
142 
143 int verify_signature(const void* ctx, enum signature_scheme sig_scheme,
144                      const uint8_t* sig, size_t sig_len, const uint8_t* data,
145                      size_t data_len)
146 {
147     // By default returns error.
148     int rv = LIBCR51SIGN_ERROR_INVALID_ARGUMENT;
149 
150     CPRINTS(ctx, "sig_len %zu sig: ", sig_len);
151     for (size_t i = 0; i < sig_len; i++)
152     {
153         CPRINTS(ctx, "%x", sig[i]);
154     }
155     CPRINTS(ctx, "\n");
156 
157     struct libcr51sign_ctx* lctx = (struct libcr51sign_ctx*)ctx;
158     FILE* fp = fopen(lctx->keyring, "r");
159     RSA *rsa = NULL, *pub_rsa = NULL;
160     EVP_PKEY* pkey = NULL;
161     BIO* bio = BIO_new(BIO_s_mem());
162     if (!fp)
163     {
164         CPRINTS(ctx, "fopen failed\n");
165         goto clean_up;
166     }
167 
168     pkey = PEM_read_PUBKEY(fp, 0, 0, 0);
169     if (!pkey)
170     {
171         CPRINTS(ctx, "Read public key failed\n");
172         goto clean_up;
173     }
174 
175     rsa = EVP_PKEY_get1_RSA(pkey);
176     if (!rsa)
177     {
178         goto clean_up;
179     }
180     pub_rsa = RSAPublicKey_dup(rsa);
181     if (!RSA_print(bio, pub_rsa, 2))
182     {
183         CPRINTS(ctx, "RSA print failed\n");
184     }
185     if (!pub_rsa)
186     {
187         CPRINTS(ctx, "no pub RSA\n");
188         goto clean_up;
189     }
190     CPRINTS(ctx, "public RSA\n");
191     char buffer[1024] = {};
192     while (BIO_read(bio, buffer, sizeof(buffer) - 1) > 0)
193     {
194         CPRINTS(ctx, " %s", buffer);
195     }
196     enum hash_type hash_type;
197     rv = get_hash_type_from_signature(sig_scheme, &hash_type);
198     if (rv != LIBCR51SIGN_SUCCESS)
199     {
200         CPRINTS(ctx, "Invalid hash_type!\n");
201         goto clean_up;
202     }
203     int hash_nid = -1;
204     if (hash_type == HASH_SHA2_256)
205     {
206         hash_nid = NID_sha256;
207     }
208     else if (hash_type == HASH_SHA2_512)
209     {
210         hash_nid = NID_sha512;
211     }
212     else
213     {
214         rv = LIBCR51SIGN_ERROR_INVALID_HASH_TYPE;
215         goto clean_up;
216     }
217 
218     int ret = RSA_verify(hash_nid, data, data_len, sig, sig_len, pub_rsa);
219     // OpenSSL RSA_verify returns 1 on success and 0 on failure
220     if (!ret)
221     {
222         CPRINTS(ctx, "OPENSSL_ERROR: %s\n",
223                 ERR_error_string(ERR_get_error(), NULL));
224         rv = LIBCR51SIGN_ERROR_RUNTIME_FAILURE;
225         goto clean_up;
226     }
227     rv = LIBCR51SIGN_SUCCESS;
228     CPRINTS(ctx, "sig: ");
229     for (size_t i = 0; i < sig_len; i++)
230     {
231         CPRINTS(ctx, "%x", sig[i]);
232     }
233     CPRINTS(ctx, "\n");
234 
235     CPRINTS(ctx, "data: ");
236     for (size_t i = 0; i < data_len; i++)
237     {
238         CPRINTS(ctx, "%x", data[i]);
239     }
240     CPRINTS(ctx, "\n");
241 
242     const unsigned rsa_size = RSA_size(pub_rsa);
243     CPRINTS(ctx, "rsa size %d sig_len %d\n", rsa_size, (uint32_t)sig_len);
244 
245 clean_up:
246     if (fp)
247     {
248         fclose(fp);
249     }
250     EVP_PKEY_free(pkey);
251     RSA_free(rsa);
252     RSA_free(pub_rsa);
253     BIO_free(bio);
254     return rv;
255 }
256 
257 // @func Verify RSA signature with modulus and exponent
258 // @param[in]  ctx - context struct
259 // @param[in]  sig_scheme - signature scheme
260 // @param[in]  modulus - modulus of the RSA key, MSB (big-endian)
261 // @param[in]  modulus_len - length of modulus in bytes
262 // @param[in]  exponent - exponent of the RSA key
263 // @param[in]  sig - signature blob
264 // @param[in]  sig_len - length of signature in bytes
265 // @param[in]  digest - digest to verify
266 // @param[in]  digest_len - digest size
267 //
268 // @return true: if the signature is verified
269 //         false: otherwise
270 __attribute__((nonnull)) bool verify_rsa_signature_with_modulus_and_exponent(
271     const void* ctx, enum signature_scheme sig_scheme, const uint8_t* modulus,
272     int modulus_len, uint32_t exponent, const uint8_t* sig, int sig_len,
273     const uint8_t* digest, int digest_len)
274 {
275     RSA* rsa = NULL;
276     BIGNUM* n = NULL;
277     BIGNUM* e = NULL;
278     int ret = 0;
279     int hash_nid = NID_undef;
280     int expected_modulus_bits = 0;
281     int expected_digest_len = 0;
282 
283     CPRINTS(ctx, "%s: sig_scheme = %d\n", __FUNCTION__, sig_scheme);
284     // Determine hash NID and expected modulus size based on signature_scheme
285     switch (sig_scheme)
286     {
287         case SIGNATURE_RSA2048_PKCS15:
288             expected_modulus_bits = 2048;
289             hash_nid = NID_sha256;
290             expected_digest_len = SHA256_DIGEST_LENGTH;
291             break;
292         case SIGNATURE_RSA3072_PKCS15:
293             expected_modulus_bits = 3072;
294             hash_nid = NID_sha256;
295             expected_digest_len = SHA256_DIGEST_LENGTH;
296             break;
297         case SIGNATURE_RSA4096_PKCS15:
298             expected_modulus_bits = 4096;
299             hash_nid = NID_sha256;
300             expected_digest_len = SHA256_DIGEST_LENGTH;
301             break;
302         case SIGNATURE_RSA4096_PKCS15_SHA512:
303             expected_modulus_bits = 4096;
304             hash_nid = NID_sha512;
305             expected_digest_len = SHA512_DIGEST_LENGTH;
306             break;
307         default:
308             CPRINTS(ctx, "%s: Unsupported signature scheme.\n", __FUNCTION__);
309             return false;
310     }
311 
312     // Input validation: Check digest length
313     if (digest_len != expected_digest_len)
314     {
315         CPRINTS(
316             ctx,
317             "%s: Mismatch in expected digest length (%d) and actual (%d).\n",
318             __FUNCTION__, expected_digest_len, digest_len);
319         return false;
320     }
321 
322     // 1. Create a new RSA object
323     rsa = RSA_new();
324     if (rsa == NULL)
325     {
326         CPRINTS(ctx, "%s:Error creating RSA object: %s\n", __FUNCTION__,
327                 ERR_error_string(ERR_get_error(), NULL));
328         goto err;
329     }
330 
331     // 2. Convert raw modulus and exponent to BIGNUMs
332     n = BN_bin2bn(modulus, modulus_len, NULL);
333     if (n == NULL)
334     {
335         CPRINTS(ctx, "%s:Error converting modulus to BIGNUM: %s\n",
336                 __FUNCTION__, ERR_error_string(ERR_get_error(), NULL));
337         goto err;
338     }
339 
340     e = BN_new();
341     if (e == NULL)
342     {
343         CPRINTS(ctx, "%s: Error creating BIGNUM for exponent: %s\n",
344                 __FUNCTION__, ERR_error_string(ERR_get_error(), NULL));
345         goto err;
346     }
347     if (!BN_set_word(e, exponent))
348     {
349         CPRINTS(ctx, "%s: Error setting exponent word: %s\n", __FUNCTION__,
350                 ERR_error_string(ERR_get_error(), NULL));
351         goto err;
352     }
353 
354     // Set the public key components. RSA_set0_key takes ownership of n and e.
355     if (!RSA_set0_key(rsa, n, e, NULL))
356     { // For public key, d is NULL
357         CPRINTS(ctx, "%s: Error setting RSA key components: %s\n", __FUNCTION__,
358                 ERR_error_string(ERR_get_error(), NULL));
359         goto err;
360     }
361     n = NULL; // Clear pointers to prevent double-free
362     e = NULL;
363 
364     if (RSA_bits(rsa) != expected_modulus_bits)
365     {
366         CPRINTS(
367             ctx,
368             "%s: Error: RSA key size (%d bits) does not match expected size for "
369             "scheme (%d bits).\n",
370             __FUNCTION__, RSA_bits(rsa), expected_modulus_bits);
371         goto err;
372     }
373 
374     // Input validation: Signature length must match modulus length
375     if (sig_len != RSA_size(rsa))
376     {
377         CPRINTS(
378             ctx,
379             "%s: Error: Signature length (%d) does not match RSA key size (%d).\n",
380             __FUNCTION__, sig_len, RSA_size(rsa));
381         goto err;
382     }
383 
384     // 3. Verify the signature
385     // RSA_verify handles the decryption, PKCS#1 v1.5 padding check, and hash
386     // comparison internally.
387     CPRINTS(ctx, "%s: RSA_verify\n", __FUNCTION__);
388     CPRINTS(ctx, "%s: hash_nid %d\n", __FUNCTION__, hash_nid);
389     CPRINTS(ctx, "%s: digest_len  %d, digest: \n", __FUNCTION__, digest_len);
390     for (int i = 0; i < digest_len; i++)
391     {
392         CPRINTS(ctx, "%x", digest[i]);
393     }
394     CPRINTS(ctx, "\n");
395 
396     CPRINTS(ctx, "%s: sig_len %d, sig: \n", __FUNCTION__, sig_len);
397     for (int i = 0; i < sig_len; i++)
398     {
399         CPRINTS(ctx, "%x", sig[i]);
400     }
401     CPRINTS(ctx, "\n");
402 
403     ret = RSA_verify(hash_nid, digest, digest_len, sig, sig_len, rsa);
404 
405     if (ret == 1)
406     {
407         CPRINTS(ctx, "%s: Signature verification successful!\n", __FUNCTION__);
408     }
409     else
410     {
411         CPRINTS(ctx, "%s: Signature verification failed: %s\n", __FUNCTION__,
412                 ERR_error_string(ERR_get_error(), NULL));
413     }
414 
415 err:
416     RSA_free(rsa); // Frees n and e if RSA_set0_key successfully took ownership
417     BN_free(n);    // Only if RSA_set0_key failed or was not called
418     BN_free(e);    // Only if RSA_set0_key failed or was not called
419     return (ret == 1);
420     (void)ctx;     // make compiler happy when CPRINTS is null statemenet
421 }
422 
423 #ifdef __cplusplus
424 } //  extern "C"
425 #endif
426