diff options
Diffstat (limited to 'src/crypto/rsa/rsa_impl.c')
-rw-r--r-- | src/crypto/rsa/rsa_impl.c | 184 |
1 files changed, 91 insertions, 93 deletions
diff --git a/src/crypto/rsa/rsa_impl.c b/src/crypto/rsa/rsa_impl.c index d950d50..e14f0f5 100644 --- a/src/crypto/rsa/rsa_impl.c +++ b/src/crypto/rsa/rsa_impl.c @@ -61,8 +61,10 @@ #include <openssl/bn.h> #include <openssl/err.h> #include <openssl/mem.h> +#include <openssl/thread.h> #include "internal.h" +#include "../internal.h" #define OPENSSL_RSA_MAX_MODULUS_BITS 16384 @@ -72,15 +74,9 @@ static int finish(RSA *rsa) { - if (rsa->_method_mod_n != NULL) { - BN_MONT_CTX_free(rsa->_method_mod_n); - } - if (rsa->_method_mod_p != NULL) { - BN_MONT_CTX_free(rsa->_method_mod_p); - } - if (rsa->_method_mod_q != NULL) { - BN_MONT_CTX_free(rsa->_method_mod_q); - } + BN_MONT_CTX_free(rsa->_method_mod_n); + BN_MONT_CTX_free(rsa->_method_mod_p); + BN_MONT_CTX_free(rsa->_method_mod_q); return 1; } @@ -165,13 +161,14 @@ static int encrypt(RSA *rsa, size_t *out_len, uint8_t *out, size_t max_out, } if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) { - if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, - ctx)) { + if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, ctx) == + NULL) { goto err; } } - if (!rsa->meth->bn_mod_exp(result, f, rsa->e, rsa->n, ctx, rsa->_method_mod_n)) { + if (!rsa->meth->bn_mod_exp(result, f, rsa->e, rsa->n, ctx, + rsa->_method_mod_n)) { goto err; } @@ -217,37 +214,20 @@ static BN_BLINDING *rsa_blinding_get(RSA *rsa, unsigned *index_used, uint8_t *new_blindings_inuse; char overflow = 0; - CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING); - if (rsa->num_blindings > 0) { - unsigned i, starting_index; - CRYPTO_THREADID threadid; - - /* We start searching the array at a value based on the - * threadid in order to try avoid bouncing the BN_BLINDING - * values around different threads. It's harmless if - * threadid.val is always set to zero. */ - CRYPTO_THREADID_current(&threadid); - starting_index = threadid.val % rsa->num_blindings; - - for (i = starting_index;;) { - if (rsa->blindings_inuse[i] == 0) { - rsa->blindings_inuse[i] = 1; - ret = rsa->blindings[i]; - *index_used = i; - break; - } - i++; - if (i == rsa->num_blindings) { - i = 0; - } - if (i == starting_index) { - break; - } + CRYPTO_MUTEX_lock_write(&rsa->lock); + + unsigned i; + for (i = 0; i < rsa->num_blindings; i++) { + if (rsa->blindings_inuse[i] == 0) { + rsa->blindings_inuse[i] = 1; + ret = rsa->blindings[i]; + *index_used = i; + break; } } if (ret != NULL) { - CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING); + CRYPTO_MUTEX_unlock(&rsa->lock); return ret; } @@ -256,7 +236,7 @@ static BN_BLINDING *rsa_blinding_get(RSA *rsa, unsigned *index_used, /* We didn't find a free BN_BLINDING to use so increase the length of * the arrays by one and use the newly created element. */ - CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING); + CRYPTO_MUTEX_unlock(&rsa->lock); ret = rsa_setup_blinding(rsa, ctx); if (ret == NULL) { return NULL; @@ -269,7 +249,7 @@ static BN_BLINDING *rsa_blinding_get(RSA *rsa, unsigned *index_used, return ret; } - CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING); + CRYPTO_MUTEX_lock_write(&rsa->lock); new_blindings = OPENSSL_malloc(sizeof(BN_BLINDING *) * (rsa->num_blindings + 1)); @@ -288,24 +268,20 @@ static BN_BLINDING *rsa_blinding_get(RSA *rsa, unsigned *index_used, new_blindings_inuse[rsa->num_blindings] = 1; *index_used = rsa->num_blindings; - if (rsa->blindings != NULL) { - OPENSSL_free(rsa->blindings); - } + OPENSSL_free(rsa->blindings); rsa->blindings = new_blindings; - if (rsa->blindings_inuse != NULL) { - OPENSSL_free(rsa->blindings_inuse); - } + OPENSSL_free(rsa->blindings_inuse); rsa->blindings_inuse = new_blindings_inuse; rsa->num_blindings++; - CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING); + CRYPTO_MUTEX_unlock(&rsa->lock); return ret; err2: OPENSSL_free(new_blindings); err1: - CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING); + CRYPTO_MUTEX_unlock(&rsa->lock); BN_BLINDING_free(ret); return NULL; } @@ -320,9 +296,9 @@ static void rsa_blinding_release(RSA *rsa, BN_BLINDING *blinding, return; } - CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING); + CRYPTO_MUTEX_lock_write(&rsa->lock); rsa->blindings_inuse[blinding_index] = 0; - CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING); + CRYPTO_MUTEX_unlock(&rsa->lock); } /* signing */ @@ -360,8 +336,7 @@ static int sign_raw(RSA *rsa, size_t *out_len, uint8_t *out, size_t max_out, } if (!RSA_private_transform(rsa, out, buf, rsa_size)) { - OPENSSL_PUT_ERROR(RSA, sign_raw, ERR_R_INTERNAL_ERROR); - goto err; + goto err; } *out_len = rsa_size; @@ -400,7 +375,6 @@ static int decrypt(RSA *rsa, size_t *out_len, uint8_t *out, size_t max_out, } if (!RSA_private_transform(rsa, buf, in, rsa_size)) { - OPENSSL_PUT_ERROR(RSA, decrypt, ERR_R_INTERNAL_ERROR); goto err; } @@ -497,8 +471,8 @@ static int verify_raw(RSA *rsa, size_t *out_len, uint8_t *out, size_t max_out, } if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) { - if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, - ctx)) { + if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, ctx) == + NULL) { goto err; } } @@ -601,8 +575,8 @@ static int private_transform(RSA *rsa, uint8_t *out, const uint8_t *in, BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME); if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) { - if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, - ctx)) { + if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, + ctx) == NULL) { goto err; } } @@ -663,18 +637,20 @@ static int mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx) { BN_with_flags(q, rsa->q, BN_FLG_CONSTTIME); if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) { - if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_p, CRYPTO_LOCK_RSA, p, ctx)) { + if (BN_MONT_CTX_set_locked(&rsa->_method_mod_p, &rsa->lock, p, ctx) == + NULL) { goto err; } - if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_q, CRYPTO_LOCK_RSA, q, ctx)) { + if (BN_MONT_CTX_set_locked(&rsa->_method_mod_q, &rsa->lock, q, ctx) == + NULL) { goto err; } } } if (rsa->flags & RSA_FLAG_CACHE_PUBLIC) { - if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, CRYPTO_LOCK_RSA, rsa->n, - ctx)) { + if (BN_MONT_CTX_set_locked(&rsa->_method_mod_n, &rsa->lock, rsa->n, ctx) == + NULL) { goto err; } } @@ -814,65 +790,79 @@ static int keygen(RSA *rsa, int bits, BIGNUM *e_value, BN_GENCB *cb) { bitsq = bits - bitsp; /* We need the RSA components non-NULL */ - if (!rsa->n && ((rsa->n = BN_new()) == NULL)) + if (!rsa->n && ((rsa->n = BN_new()) == NULL)) { goto err; - if (!rsa->d && ((rsa->d = BN_new()) == NULL)) + } + if (!rsa->d && ((rsa->d = BN_new()) == NULL)) { goto err; - if (!rsa->e && ((rsa->e = BN_new()) == NULL)) + } + if (!rsa->e && ((rsa->e = BN_new()) == NULL)) { goto err; - if (!rsa->p && ((rsa->p = BN_new()) == NULL)) + } + if (!rsa->p && ((rsa->p = BN_new()) == NULL)) { goto err; - if (!rsa->q && ((rsa->q = BN_new()) == NULL)) + } + if (!rsa->q && ((rsa->q = BN_new()) == NULL)) { goto err; - if (!rsa->dmp1 && ((rsa->dmp1 = BN_new()) == NULL)) + } + if (!rsa->dmp1 && ((rsa->dmp1 = BN_new()) == NULL)) { goto err; - if (!rsa->dmq1 && ((rsa->dmq1 = BN_new()) == NULL)) + } + if (!rsa->dmq1 && ((rsa->dmq1 = BN_new()) == NULL)) { goto err; - if (!rsa->iqmp && ((rsa->iqmp = BN_new()) == NULL)) + } + if (!rsa->iqmp && ((rsa->iqmp = BN_new()) == NULL)) { goto err; + } BN_copy(rsa->e, e_value); /* generate p and q */ for (;;) { - if (!BN_generate_prime_ex(rsa->p, bitsp, 0, NULL, NULL, cb)) - goto err; - if (!BN_sub(r2, rsa->p, BN_value_one())) - goto err; - if (!BN_gcd(r1, r2, rsa->e, ctx)) + if (!BN_generate_prime_ex(rsa->p, bitsp, 0, NULL, NULL, cb) || + !BN_sub(r2, rsa->p, BN_value_one()) || + !BN_gcd(r1, r2, rsa->e, ctx)) { goto err; - if (BN_is_one(r1)) + } + if (BN_is_one(r1)) { break; - if (!BN_GENCB_call(cb, 2, n++)) + } + if (!BN_GENCB_call(cb, 2, n++)) { goto err; + } } - if (!BN_GENCB_call(cb, 3, 0)) + if (!BN_GENCB_call(cb, 3, 0)) { goto err; + } for (;;) { /* When generating ridiculously small keys, we can get stuck * continually regenerating the same prime values. Check for * this and bail if it happens 3 times. */ unsigned int degenerate = 0; do { - if (!BN_generate_prime_ex(rsa->q, bitsq, 0, NULL, NULL, cb)) + if (!BN_generate_prime_ex(rsa->q, bitsq, 0, NULL, NULL, cb)) { goto err; + } } while ((BN_cmp(rsa->p, rsa->q) == 0) && (++degenerate < 3)); if (degenerate == 3) { ok = 0; /* we set our own err */ OPENSSL_PUT_ERROR(RSA, keygen, RSA_R_KEY_SIZE_TOO_SMALL); goto err; } - if (!BN_sub(r2, rsa->q, BN_value_one())) - goto err; - if (!BN_gcd(r1, r2, rsa->e, ctx)) + if (!BN_sub(r2, rsa->q, BN_value_one()) || + !BN_gcd(r1, r2, rsa->e, ctx)) { goto err; - if (BN_is_one(r1)) + } + if (BN_is_one(r1)) { break; - if (!BN_GENCB_call(cb, 2, n++)) + } + if (!BN_GENCB_call(cb, 2, n++)) { goto err; + } } - if (!BN_GENCB_call(cb, 3, 1)) + if (!BN_GENCB_call(cb, 3, 1)) { goto err; + } if (BN_cmp(rsa->p, rsa->q) < 0) { tmp = rsa->p; rsa->p = rsa->q; @@ -880,39 +870,47 @@ static int keygen(RSA *rsa, int bits, BIGNUM *e_value, BN_GENCB *cb) { } /* calculate n */ - if (!BN_mul(rsa->n, rsa->p, rsa->q, ctx)) + if (!BN_mul(rsa->n, rsa->p, rsa->q, ctx)) { goto err; + } /* calculate d */ - if (!BN_sub(r1, rsa->p, BN_value_one())) + if (!BN_sub(r1, rsa->p, BN_value_one())) { goto err; /* p-1 */ - if (!BN_sub(r2, rsa->q, BN_value_one())) + } + if (!BN_sub(r2, rsa->q, BN_value_one())) { goto err; /* q-1 */ - if (!BN_mul(r0, r1, r2, ctx)) + } + if (!BN_mul(r0, r1, r2, ctx)) { goto err; /* (p-1)(q-1) */ + } pr0 = &local_r0; BN_with_flags(pr0, r0, BN_FLG_CONSTTIME); - if (!BN_mod_inverse(rsa->d, rsa->e, pr0, ctx)) + if (!BN_mod_inverse(rsa->d, rsa->e, pr0, ctx)) { goto err; /* d */ + } /* set up d for correct BN_FLG_CONSTTIME flag */ d = &local_d; BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME); /* calculate d mod (p-1) */ - if (!BN_mod(rsa->dmp1, d, r1, ctx)) + if (!BN_mod(rsa->dmp1, d, r1, ctx)) { goto err; + } /* calculate d mod (q-1) */ - if (!BN_mod(rsa->dmq1, d, r2, ctx)) + if (!BN_mod(rsa->dmq1, d, r2, ctx)) { goto err; + } /* calculate inverse of q mod p */ p = &local_p; BN_with_flags(p, rsa->p, BN_FLG_CONSTTIME); - if (!BN_mod_inverse(rsa->iqmp, rsa->q, p, ctx)) + if (!BN_mod_inverse(rsa->iqmp, rsa->q, p, ctx)) { goto err; + } ok = 1; |