diff --git a/library/pk_wrap.c b/library/pk_wrap.c index db6274cbf..bdc0f3927 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -68,7 +68,8 @@ static int rsa_can_do( mbedtls_pk_type_t type ) static size_t rsa_get_bitlen( const void *ctx ) { - return( 8 * ((const mbedtls_rsa_context *) ctx)->len ); + const mbedtls_rsa_context * rsa = (const mbedtls_rsa_context *) ctx; + return( 8 * mbedtls_rsa_get_len( rsa ) ); } static int rsa_verify_wrap( void *ctx, mbedtls_md_type_t md_alg, @@ -76,21 +77,23 @@ static int rsa_verify_wrap( void *ctx, mbedtls_md_type_t md_alg, const unsigned char *sig, size_t sig_len ) { int ret; + mbedtls_rsa_context * rsa = (mbedtls_rsa_context *) ctx; + size_t rsa_len = mbedtls_rsa_get_len( rsa ); #if defined(MBEDTLS_HAVE_INT64) if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len ) return( MBEDTLS_ERR_PK_BAD_INPUT_DATA ); #endif /* MBEDTLS_HAVE_INT64 */ - if( sig_len < ((mbedtls_rsa_context *) ctx)->len ) + if( sig_len < rsa_len ) return( MBEDTLS_ERR_RSA_VERIFY_FAILED ); - if( ( ret = mbedtls_rsa_pkcs1_verify( (mbedtls_rsa_context *) ctx, NULL, NULL, + if( ( ret = mbedtls_rsa_pkcs1_verify( rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, md_alg, (unsigned int) hash_len, hash, sig ) ) != 0 ) return( ret ); - if( sig_len > ((mbedtls_rsa_context *) ctx)->len ) + if( sig_len > rsa_len ) return( MBEDTLS_ERR_PK_SIG_LEN_MISMATCH ); return( 0 ); @@ -101,14 +104,16 @@ static int rsa_sign_wrap( void *ctx, mbedtls_md_type_t md_alg, unsigned char *sig, size_t *sig_len, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) { + mbedtls_rsa_context * rsa = (mbedtls_rsa_context *) ctx; + #if defined(MBEDTLS_HAVE_INT64) if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len ) return( MBEDTLS_ERR_PK_BAD_INPUT_DATA ); #endif /* MBEDTLS_HAVE_INT64 */ - *sig_len = ((mbedtls_rsa_context *) ctx)->len; + *sig_len = mbedtls_rsa_get_len( rsa ); - return( mbedtls_rsa_pkcs1_sign( (mbedtls_rsa_context *) ctx, f_rng, p_rng, MBEDTLS_RSA_PRIVATE, + return( mbedtls_rsa_pkcs1_sign( rsa, f_rng, p_rng, MBEDTLS_RSA_PRIVATE, md_alg, (unsigned int) hash_len, hash, sig ) ); } @@ -117,10 +122,12 @@ static int rsa_decrypt_wrap( void *ctx, unsigned char *output, size_t *olen, size_t osize, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) { - if( ilen != ((mbedtls_rsa_context *) ctx)->len ) + mbedtls_rsa_context * rsa = (mbedtls_rsa_context *) ctx; + + if( ilen != mbedtls_rsa_get_len( rsa ) ) return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA ); - return( mbedtls_rsa_pkcs1_decrypt( (mbedtls_rsa_context *) ctx, f_rng, p_rng, + return( mbedtls_rsa_pkcs1_decrypt( rsa, f_rng, p_rng, MBEDTLS_RSA_PRIVATE, olen, input, output, osize ) ); } @@ -129,13 +136,14 @@ static int rsa_encrypt_wrap( void *ctx, unsigned char *output, size_t *olen, size_t osize, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) { - *olen = ((mbedtls_rsa_context *) ctx)->len; + mbedtls_rsa_context * rsa = (mbedtls_rsa_context *) ctx; + *olen = mbedtls_rsa_get_len( rsa ); if( *olen > osize ) return( MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE ); - return( mbedtls_rsa_pkcs1_encrypt( (mbedtls_rsa_context *) ctx, - f_rng, p_rng, MBEDTLS_RSA_PUBLIC, ilen, input, output ) ); + return( mbedtls_rsa_pkcs1_encrypt( rsa, f_rng, p_rng, MBEDTLS_RSA_PUBLIC, + ilen, input, output ) ); } static int rsa_check_pair_wrap( const void *pub, const void *prv )