From dffc710008b5f2de69087c7a194c1170a5fa0fa9 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Thu, 10 Jun 2021 15:34:15 +0200 Subject: [PATCH] Test the validity of the sign bit after constructing an MPI object This is mostly to look for cases where the sign bit may have been left at 0 after zerozing memory, or a value of 0 with the sign bit set to -11. Both of these mostly work fine, so they can go otherwise undetected by unit tests, but they can break when certain combinations of functions are used. Signed-off-by: Gilles Peskine --- tests/suites/test_suite_mpi.function | 63 ++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/suites/test_suite_mpi.function b/tests/suites/test_suite_mpi.function index eaede01b4..d45d3dcf8 100644 --- a/tests/suites/test_suite_mpi.function +++ b/tests/suites/test_suite_mpi.function @@ -6,6 +6,18 @@ #define MPI_MAX_BITS_LARGER_THAN_792 #endif +/* Check the validity of the sign bit in an MPI object. Reject representations + * that are not supported by the rest of the library and indicate a bug when + * constructing the value. */ +static int sign_is_valid( const mbedtls_mpi *X ) +{ + if( X->s != 1 && X->s != -1 ) + return( 0 ); // invalid sign bit, e.g. 0 + if( mbedtls_mpi_bitlen( X ) == 0 && X->s != 1 ) + return( 0 ); // negative zero + return( 1 ); +} + typedef struct mbedtls_test_mpi_random { data_t *data; @@ -150,6 +162,7 @@ void mpi_read_write_string( int radix_X, char * input_X, int radix_A, TEST_ASSERT( mbedtls_mpi_read_string( &X, radix_X, input_X ) == result_read ); if( result_read == 0 ) { + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_write_string( &X, radix_A, str, output_size, &len ) == result_write ); if( result_write == 0 ) { @@ -174,6 +187,7 @@ void mbedtls_mpi_read_binary( data_t * buf, int radix_A, char * input_A ) TEST_ASSERT( mbedtls_mpi_read_binary( &X, buf->x, buf->len ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_write_string( &X, radix_A, str, sizeof( str ), &len ) == 0 ); TEST_ASSERT( strcmp( (char *) str, input_A ) == 0 ); @@ -193,6 +207,7 @@ void mbedtls_mpi_read_binary_le( data_t * buf, int radix_A, char * input_A ) TEST_ASSERT( mbedtls_mpi_read_binary_le( &X, buf->x, buf->len ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_write_string( &X, radix_A, str, sizeof( str ), &len ) == 0 ); TEST_ASSERT( strcmp( (char *) str, input_A ) == 0 ); @@ -287,6 +302,7 @@ void mbedtls_mpi_read_file( int radix_X, char * input_file, if( result == 0 ) { + TEST_ASSERT( sign_is_valid( &X ) ); buflen = mbedtls_mpi_size( &X ); TEST_ASSERT( mbedtls_mpi_write_binary( &X, buf, buflen ) == 0 ); @@ -357,6 +373,7 @@ void mbedtls_mpi_set_bit( int radix_X, char * input_X, int pos, int val, if( result == 0 ) { + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &Y ) == 0 ); } @@ -404,6 +421,7 @@ void mbedtls_mpi_gcd( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( mbedtls_test_read_mpi( &Y, radix_Y, input_Y ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_gcd( &Z, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); exit: @@ -492,6 +510,7 @@ void mbedtls_mpi_copy_sint( int input_X, int input_Y ) TEST_ASSERT( mbedtls_mpi_lset( &Y, input_Y ) == 0 ); TEST_ASSERT( mbedtls_mpi_copy( &Y, &X ) == 0 ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_int( &X, input_X ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_int( &Y, input_X ) == 0 ); @@ -512,6 +531,7 @@ void mbedtls_mpi_copy_binary( data_t *input_X, data_t *input_Y ) TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &X0 ) == 0 ); TEST_ASSERT( mbedtls_mpi_copy( &Y, &X ) == 0 ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &X0 ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &X0 ) == 0 ); @@ -566,9 +586,11 @@ void mbedtls_mpi_safe_cond_assign( int x_sign, char * x_str, int y_sign, TEST_ASSERT( mbedtls_mpi_copy( &XX, &X ) == 0 ); TEST_ASSERT( mbedtls_mpi_safe_cond_assign( &X, &Y, 0 ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &XX ) == 0 ); TEST_ASSERT( mbedtls_mpi_safe_cond_assign( &X, &Y, 1 ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &Y ) == 0 ); exit: @@ -594,10 +616,14 @@ void mbedtls_mpi_safe_cond_swap( int x_sign, char * x_str, int y_sign, TEST_ASSERT( mbedtls_mpi_copy( &YY, &Y ) == 0 ); TEST_ASSERT( mbedtls_mpi_safe_cond_swap( &X, &Y, 0 ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &XX ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &YY ) == 0 ); TEST_ASSERT( mbedtls_mpi_safe_cond_swap( &X, &Y, 1 ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &XX ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &YY ) == 0 ); @@ -619,6 +645,8 @@ void mbedtls_mpi_swap_sint( int input_X, int input_Y ) TEST_ASSERT( mbedtls_mpi_cmp_int( &Y, input_Y ) == 0 ); mbedtls_mpi_swap( &X, &Y ); + TEST_ASSERT( sign_is_valid( &X ) ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_int( &X, input_Y ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_int( &Y, input_X ) == 0 ); @@ -640,6 +668,8 @@ void mbedtls_mpi_swap_binary( data_t *input_X, data_t *input_Y ) TEST_ASSERT( mbedtls_mpi_read_binary( &Y0, input_Y->x, input_Y->len ) == 0 ); mbedtls_mpi_swap( &X, &Y ); + TEST_ASSERT( sign_is_valid( &X ) ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &Y0 ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &X0 ) == 0 ); @@ -659,6 +689,7 @@ void mpi_swap_self( data_t *input_X ) TEST_ASSERT( mbedtls_mpi_read_binary( &X0, input_X->x, input_X->len ) == 0 ); mbedtls_mpi_swap( &X, &X ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &X0 ) == 0 ); exit: @@ -677,15 +708,18 @@ void mbedtls_mpi_add_mpi( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( mbedtls_test_read_mpi( &Y, radix_Y, input_Y ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_add_mpi( &Z, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); /* result == first operand */ TEST_ASSERT( mbedtls_mpi_add_mpi( &X, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); /* result == second operand */ TEST_ASSERT( mbedtls_mpi_add_mpi( &Y, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &A ) == 0 ); exit: @@ -705,13 +739,16 @@ void mbedtls_mpi_add_mpi_inplace( int radix_X, char * input_X, int radix_A, TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_mpi_sub_abs( &X, &X, &X ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_int( &X, 0 ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_mpi_add_abs( &X, &X, &X ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_mpi_add_mpi( &X, &X, &X ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); exit: @@ -731,15 +768,18 @@ void mbedtls_mpi_add_abs( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( mbedtls_test_read_mpi( &Y, radix_Y, input_Y ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_add_abs( &Z, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); /* result == first operand */ TEST_ASSERT( mbedtls_mpi_add_abs( &X, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); /* result == second operand */ TEST_ASSERT( mbedtls_mpi_add_abs( &Y, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &A ) == 0 ); exit: @@ -757,6 +797,7 @@ void mbedtls_mpi_add_int( int radix_X, char * input_X, int input_Y, TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_add_int( &Z, &X, input_Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); exit: @@ -775,15 +816,18 @@ void mbedtls_mpi_sub_mpi( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( mbedtls_test_read_mpi( &Y, radix_Y, input_Y ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_sub_mpi( &Z, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); /* result == first operand */ TEST_ASSERT( mbedtls_mpi_sub_mpi( &X, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); /* result == second operand */ TEST_ASSERT( mbedtls_mpi_sub_mpi( &Y, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Y ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &A ) == 0 ); exit: @@ -806,17 +850,20 @@ void mbedtls_mpi_sub_abs( int radix_X, char * input_X, int radix_Y, res = mbedtls_mpi_sub_abs( &Z, &X, &Y ); TEST_ASSERT( res == sub_result ); + TEST_ASSERT( sign_is_valid( &Z ) ); if( res == 0 ) TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); /* result == first operand */ TEST_ASSERT( mbedtls_mpi_sub_abs( &X, &X, &Y ) == sub_result ); + TEST_ASSERT( sign_is_valid( &X ) ); if( sub_result == 0 ) TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); /* result == second operand */ TEST_ASSERT( mbedtls_mpi_sub_abs( &Y, &X, &Y ) == sub_result ); + TEST_ASSERT( sign_is_valid( &Y ) ); if( sub_result == 0 ) TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &A ) == 0 ); @@ -835,6 +882,7 @@ void mbedtls_mpi_sub_int( int radix_X, char * input_X, int input_Y, TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_sub_int( &Z, &X, input_Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); exit: @@ -853,6 +901,7 @@ void mbedtls_mpi_mul_mpi( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( mbedtls_test_read_mpi( &Y, radix_Y, input_Y ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_mul_mpi( &Z, &X, &Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); exit: @@ -871,6 +920,7 @@ void mbedtls_mpi_mul_int( int radix_X, char * input_X, int input_Y, TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_mul_int( &Z, &X, input_Y ) == 0 ); + TEST_ASSERT( sign_is_valid( &Z ) ); if( strcmp( result_comparison, "==" ) == 0 ) TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); else if( strcmp( result_comparison, "!=" ) == 0 ) @@ -901,6 +951,8 @@ void mbedtls_mpi_div_mpi( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( res == div_result ); if( res == 0 ) { + TEST_ASSERT( sign_is_valid( &Q ) ); + TEST_ASSERT( sign_is_valid( &R ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Q, &A ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &R, &B ) == 0 ); } @@ -928,6 +980,8 @@ void mbedtls_mpi_div_int( int radix_X, char * input_X, int input_Y, TEST_ASSERT( res == div_result ); if( res == 0 ) { + TEST_ASSERT( sign_is_valid( &Q ) ); + TEST_ASSERT( sign_is_valid( &R ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Q, &A ) == 0 ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &R, &B ) == 0 ); } @@ -954,6 +1008,7 @@ void mbedtls_mpi_mod_mpi( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( res == div_result ); if( res == 0 ) { + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); } @@ -1007,6 +1062,7 @@ void mbedtls_mpi_exp_mod( int radix_A, char * input_A, int radix_E, TEST_ASSERT( res == div_result ); if( res == 0 ) { + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &X ) == 0 ); } @@ -1066,6 +1122,7 @@ void mbedtls_mpi_inv_mod( int radix_X, char * input_X, int radix_Y, TEST_ASSERT( res == div_result ); if( res == 0 ) { + TEST_ASSERT( sign_is_valid( &Z ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Z, &A ) == 0 ); } @@ -1141,6 +1198,7 @@ void mbedtls_mpi_gen_prime( int bits, int flags, int ref_ret ) TEST_ASSERT( actual_bits >= (size_t) bits ); TEST_ASSERT( actual_bits <= (size_t) bits + 1 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_is_prime_ext( &X, 40, mbedtls_test_rnd_std_rand, @@ -1170,6 +1228,7 @@ void mbedtls_mpi_shift_l( int radix_X, char * input_X, int shift_X, TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_shift_l( &X, shift_X ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); exit: @@ -1187,6 +1246,7 @@ void mbedtls_mpi_shift_r( int radix_X, char * input_X, int shift_X, TEST_ASSERT( mbedtls_test_read_mpi( &X, radix_X, input_X ) == 0 ); TEST_ASSERT( mbedtls_test_read_mpi( &A, radix_A, input_A ) == 0 ); TEST_ASSERT( mbedtls_mpi_shift_r( &X, shift_X ) == 0 ); + TEST_ASSERT( sign_is_valid( &X ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &A ) == 0 ); exit: @@ -1228,6 +1288,7 @@ void mpi_fill_random( int wanted_bytes, int rng_bytes, TEST_ASSERT( mbedtls_mpi_size( &X ) + leading_zeros == (size_t) wanted_bytes ); TEST_ASSERT( (int) bytes_left == rng_bytes - wanted_bytes ); + TEST_ASSERT( sign_is_valid( &X ) ); } exit: @@ -1285,6 +1346,7 @@ void mpi_random_many( int min, data_t *bound_bytes, int iterations ) TEST_EQUAL( 0, mbedtls_mpi_random( &result, min, &upper_bound, mbedtls_test_rnd_std_rand, NULL ) ); + TEST_ASSERT( sign_is_valid( &result ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &result, &upper_bound ) < 0 ); TEST_ASSERT( mbedtls_mpi_cmp_int( &result, min ) >= 0 ); if( full_stats ) @@ -1366,6 +1428,7 @@ void mpi_random_sizes( int min, data_t *bound_bytes, int nlimbs, int before ) bound_bytes->x, bound_bytes->len ) ); TEST_EQUAL( 0, mbedtls_mpi_random( &result, min, &upper_bound, mbedtls_test_rnd_std_rand, NULL ) ); + TEST_ASSERT( sign_is_valid( &result ) ); TEST_ASSERT( mbedtls_mpi_cmp_mpi( &result, &upper_bound ) < 0 ); TEST_ASSERT( mbedtls_mpi_cmp_int( &result, min ) >= 0 );