diff --git a/library/bignum_mod.h b/library/bignum_mod.h index 432cd0070..342bf17b5 100644 --- a/library/bignum_mod.h +++ b/library/bignum_mod.h @@ -67,7 +67,8 @@ void mbedtls_mpi_mod_residue_release( mbedtls_mpi_mod_residue *r ); int mbedtls_mpi_mod_residue_setup( mbedtls_mpi_mod_residue *r, mbedtls_mpi_mod_modulus *m, - mbedtls_mpi_uint *p ); + mbedtls_mpi_uint *p, + size_t pn ); void mbedtls_mpi_mod_modulus_init( mbedtls_mpi_mod_modulus *m ); diff --git a/library/bignum_new.c b/library/bignum_new.c index 04f604923..60f50ada8 100644 --- a/library/bignum_new.c +++ b/library/bignum_new.c @@ -29,6 +29,7 @@ #include "bignum_core.h" #include "bignum_mod.h" #include "bignum_mod_raw.h" +#include "constant_time_internal.h" #if defined(MBEDTLS_PLATFORM_C) #include "mbedtls/platform.h" @@ -92,13 +93,17 @@ void mbedtls_mpi_mod_residue_release( mbedtls_mpi_mod_residue *r ) int mbedtls_mpi_mod_residue_setup( mbedtls_mpi_mod_residue *r, mbedtls_mpi_mod_modulus *m, - mbedtls_mpi_uint *X ) + mbedtls_mpi_uint *p, + size_t pn ) { - if( X == NULL || m == NULL || r == NULL || X >= m->p) + if( p == NULL || m == NULL || r == NULL ) + return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); + + if( pn < m->n || !mbedtls_mpi_core_lt_ct( m->p, p, pn ) ) return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); r->n = m->n; - r->p = X; + r->p = p; return( 0 ); } @@ -447,16 +452,28 @@ int mbedtls_mpi_mod_raw_read( mbedtls_mpi_uint *X, unsigned char *buf, size_t buflen ) { + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + if( m->ext_rep & MBEDTLS_MPI_MOD_EXT_REP_LE ) - return mbedtls_mpi_core_read_le( X, m->n, buf, buflen ); + ret = mbedtls_mpi_core_read_le( X, m->n, buf, buflen ); else if( m->ext_rep & MBEDTLS_MPI_MOD_EXT_REP_BE ) - return mbedtls_mpi_core_read_be( X, m->n, buf, buflen ); - + ret = mbedtls_mpi_core_read_be( X, m->n, buf, buflen ); else return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); - return( 0 ); + if( ret != 0 ) + goto cleanup; + + if( !mbedtls_mpi_core_lt_ct( X, m->p, m->n ) ) + { + ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA; + goto cleanup; + } + +cleanup: + + return( ret ); } int mbedtls_mpi_mod_raw_write( mbedtls_mpi_uint *X,