Fix psa_pake_get_implicit_key() state & add corresponding tests in ecjpake_rounds()

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
This commit is contained in:
Neil Armstrong 2022-06-15 11:32:11 +02:00
parent ed40782628
commit 1e855601ca
2 changed files with 40 additions and 18 deletions

View file

@ -660,8 +660,8 @@ psa_status_t psa_pake_get_implicit_key(psa_pake_operation_t *operation,
if( operation->alg == 0 ||
operation->state != PSA_PAKE_STATE_READY ||
( operation->input_step != PSA_PAKE_STEP_DERIVE &&
operation->output_step != PSA_PAKE_STEP_DERIVE ) )
operation->input_step != PSA_PAKE_STEP_DERIVE ||
operation->output_step != PSA_PAKE_STEP_DERIVE )
return( PSA_ERROR_BAD_STATE );
#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)

View file

@ -8316,6 +8316,21 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
psa_pake_cs_set_primitive( &cipher_suite, primitive_arg );
psa_pake_cs_set_hash( &cipher_suite, hash_alg );
/* Get shared key */
PSA_ASSERT( psa_key_derivation_setup( &server_derive, derive_alg ) );
PSA_ASSERT( psa_key_derivation_setup( &client_derive, derive_alg ) );
if( PSA_ALG_IS_TLS12_PRF( derive_alg ) ||
PSA_ALG_IS_TLS12_PSK_TO_MS( derive_alg ) )
{
PSA_ASSERT( psa_key_derivation_input_bytes( &server_derive,
PSA_KEY_DERIVATION_INPUT_SEED,
(const uint8_t*) "", 0) );
PSA_ASSERT( psa_key_derivation_input_bytes( &client_derive,
PSA_KEY_DERIVATION_INPUT_SEED,
(const uint8_t*) "", 0) );
}
PSA_ASSERT( psa_pake_setup( &server, &cipher_suite ) );
PSA_ASSERT( psa_pake_setup( &client, &cipher_suite ) );
@ -8325,6 +8340,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
PSA_ASSERT( psa_pake_set_password_key( &server, key ) );
PSA_ASSERT( psa_pake_set_password_key( &client, key ) );
TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
PSA_ERROR_BAD_STATE );
TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
PSA_ERROR_BAD_STATE );
/* Server first round Output */
PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_KEY_SHARE,
buffer0 + buffer0_off,
@ -8389,6 +8409,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
c_x2_pr_off = buffer1_off;
buffer1_off += c_x2_pr_len;
TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
PSA_ERROR_BAD_STATE );
TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
PSA_ERROR_BAD_STATE );
/* Client first round Input */
PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE,
buffer0 + s_g1_off, s_g1_len ) );
@ -8417,6 +8442,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF,
buffer1 + c_x2_pr_off, c_x2_pr_len ) );
TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
PSA_ERROR_BAD_STATE );
TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
PSA_ERROR_BAD_STATE );
/* Server second round Output */
buffer0_off = 0;
@ -8455,6 +8485,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
c_x2s_pr_off = buffer1_off;
buffer1_off += c_x2s_pr_len;
TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
PSA_ERROR_BAD_STATE );
TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
PSA_ERROR_BAD_STATE );
/* Client second round Input */
PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE,
buffer0 + s_a_off, s_a_len ) );
@ -8463,6 +8498,9 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PROOF,
buffer0 + s_x2s_pr_off, s_x2s_pr_len ) );
TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
PSA_ERROR_BAD_STATE );
/* Server second round Input */
PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_KEY_SHARE,
buffer1 + c_a_off, c_a_len ) );
@ -8471,22 +8509,6 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF,
buffer1 + c_x2s_pr_off, c_x2s_pr_len ) );
/* Get shared key */
PSA_ASSERT( psa_key_derivation_setup( &server_derive, derive_alg ) );
PSA_ASSERT( psa_key_derivation_setup( &client_derive, derive_alg ) );
if( PSA_ALG_IS_TLS12_PRF( derive_alg ) ||
PSA_ALG_IS_TLS12_PSK_TO_MS( derive_alg ) )
{
PSA_ASSERT( psa_key_derivation_input_bytes( &server_derive,
PSA_KEY_DERIVATION_INPUT_SEED,
(const uint8_t*) "", 0) );
PSA_ASSERT( psa_key_derivation_input_bytes( &client_derive,
PSA_KEY_DERIVATION_INPUT_SEED,
(const uint8_t*) "", 0) );
}
PSA_ASSERT( psa_pake_get_implicit_key( &server, &server_derive ) );
PSA_ASSERT( psa_pake_get_implicit_key( &client, &client_derive ) );