diff --git a/library/rsa.c b/library/rsa.c index db0b0f74f..32a26500e 100644 --- a/library/rsa.c +++ b/library/rsa.c @@ -28,6 +28,7 @@ #if defined(MBEDTLS_RSA_C) #include "mbedtls/rsa.h" +#include "bignum_core.h" #include "rsa_alt_helpers.h" #include "mbedtls/oid.h" #include "mbedtls/platform_util.h" @@ -969,6 +970,40 @@ cleanup: return ret; } +/* + * Unblind + * T = T * Vf mod N + */ +int rsa_unblind(mbedtls_mpi* T, mbedtls_mpi* Vf, mbedtls_mpi* N) +{ + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + const mbedtls_mpi_uint mm = mbedtls_mpi_core_montmul_init(N->p); + const size_t nlimbs = N->n; + const size_t tlimbs = mbedtls_mpi_core_montmul_working_limbs(nlimbs); + mbedtls_mpi RR, M_T; + + mbedtls_mpi_init(&RR); + mbedtls_mpi_init(&M_T); + + MBEDTLS_MPI_CHK(mbedtls_mpi_core_get_mont_r2_unsafe(&RR, N)); + MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&M_T, tlimbs)); + + MBEDTLS_MPI_CHK(mbedtls_mpi_grow(T, nlimbs)); + MBEDTLS_MPI_CHK(mbedtls_mpi_grow(Vf, nlimbs)); + + // T = T * R mod N + mbedtls_mpi_core_to_mont_rep(T->p, T->p, N->p, nlimbs, mm, RR.p, M_T.p); + // T = T * Vf mod N + mbedtls_mpi_core_montmul(T->p, T->p, Vf->p, nlimbs, N->p, nlimbs, mm, M_T.p); + +cleanup: + + mbedtls_mpi_free(&RR); + mbedtls_mpi_free(&M_T); + + return ret; +} + /* * Exponent blinding supposed to prevent side-channel attacks using multiple * traces of measurements to recover the RSA key. The more collisions are there, @@ -1160,8 +1195,7 @@ int mbedtls_rsa_private(mbedtls_rsa_context *ctx, * Unblind * T = T * Vf mod N */ - MBEDTLS_MPI_CHK(mbedtls_mpi_mul_mpi(&T, &T, &ctx->Vf)); - MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(&T, &T, &ctx->N)); + MBEDTLS_MPI_CHK(rsa_unblind(&T, &ctx->Vf, &ctx->N)); /* Verify the result to prevent glitching attacks. */ MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod(&C, &T, &ctx->E,