diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 5c05f7928..cbdc91293 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -5148,7 +5148,7 @@ static psa_status_t psa_key_derivation_setup_kdf( return( PSA_ERROR_NOT_SUPPORTED ); /* All currently supported key derivation algorithms (apart from - * ecjpake to pms are based on a hash algorithm. */ + * ecjpake to pms) are based on a hash algorithm. */ psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( kdf_alg ); size_t hash_size = PSA_HASH_LENGTH( hash_alg ); if( !PSA_ALG_IS_TLS12_ECJPAKE_TO_PMS( kdf_alg ) ) @@ -5570,10 +5570,12 @@ static psa_status_t psa_tls12_prf_psk_to_ms_input( #if defined(MBEDTLS_PSA_BUILTIN_ALG_TLS12_ECJPAKE_TO_PMS) static psa_status_t psa_tls12_ecjpake_to_pms_input( psa_tls12_ecjpake_to_pms_t *ecjpake, + psa_key_derivation_step_t step, const uint8_t *data, size_t data_length ) { - if( data_length != PSA_TLS12_ECJPAKE_TO_PMS_INPUT_SIZE ) + if( data_length != PSA_TLS12_ECJPAKE_TO_PMS_INPUT_SIZE || + step != PSA_KEY_DERIVATION_INPUT_SECRET ) return( PSA_ERROR_INVALID_ARGUMENT ); /* Check if the passed point is in an uncompressed form */ @@ -5668,7 +5670,7 @@ static psa_status_t psa_key_derivation_input_internal( if( PSA_ALG_IS_TLS12_ECJPAKE_TO_PMS( kdf_alg ) ) { status = psa_tls12_ecjpake_to_pms_input( - &operation->ctx.tls12_ecjpake_to_pms, data, data_length ); + &operation->ctx.tls12_ecjpake_to_pms, step, data, data_length ); } else #endif /* MBEDTLS_PSA_BUILTIN_ALG_TLS12_ECJPAKE_TO_PMS */