diff --git a/library/psa_crypto.c b/library/psa_crypto.c index b4fad33d3..4f3d774af 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -7241,8 +7241,7 @@ psa_status_t psa_pake_setup( return PSA_ERROR_BAD_STATE; } - if (cipher_suite == NULL || - PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 || + if (PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 || PSA_ALG_IS_HASH(cipher_suite->hash) == 0) { return PSA_ERROR_INVALID_ARGUMENT; } @@ -7436,17 +7435,12 @@ static psa_pake_driver_step_t convert_jpake_computation_stage_to_driver_step( static psa_status_t psa_pake_complete_inputs( psa_pake_operation_t *operation) { - psa_jpake_computation_stage_t *computation_stage = - &operation->computation_stage.jpake; psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; - uint8_t *password = operation->data.inputs.password; - size_t password_len = operation->data.inputs.password_len; /* Create copy of the inputs on stack as inputs share memory with the driver context which will be setup by the driver. */ psa_crypto_driver_pake_inputs_t inputs = operation->data.inputs; - if (operation->alg == PSA_ALG_NONE || - operation->data.inputs.password_len == 0 || + if (operation->data.inputs.password_len == 0 || operation->data.inputs.role == PSA_PAKE_ROLE_NONE) { return PSA_ERROR_BAD_STATE; } @@ -7457,12 +7451,14 @@ static psa_status_t psa_pake_complete_inputs( status = psa_driver_wrapper_pake_setup(operation, &inputs); /* Driver is responsible for creating its own copy of the password. */ - mbedtls_platform_zeroize(password, password_len); - mbedtls_free(password); + mbedtls_platform_zeroize(inputs.password, inputs.password_len); + mbedtls_free(inputs.password); if (status == PSA_SUCCESS) { operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION; if (operation->alg == PSA_ALG_JPAKE) { + psa_jpake_computation_stage_t *computation_stage = + &operation->computation_stage.jpake; computation_stage->state = PSA_PAKE_STATE_READY; computation_stage->sequence = PSA_PAKE_SEQ_INVALID; computation_stage->input_step = PSA_PAKE_STEP_X1_X2; @@ -7576,6 +7572,7 @@ psa_status_t psa_pake_output( size_t *output_length) { psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + *output_length = 0; if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) { status = psa_pake_complete_inputs(operation); @@ -7588,11 +7585,7 @@ psa_status_t psa_pake_output( return PSA_ERROR_BAD_STATE; } - if (operation->id == 0) { - return PSA_ERROR_BAD_STATE; - } - - if (output == NULL || output_size == 0) { + if (output_size == 0) { return PSA_ERROR_INVALID_ARGUMENT; } @@ -7750,11 +7743,7 @@ psa_status_t psa_pake_input( return PSA_ERROR_BAD_STATE; } - if (operation->id == 0) { - return PSA_ERROR_BAD_STATE; - } - - if (input == NULL || input_length == 0) { + if (input_length == 0) { return PSA_ERROR_INVALID_ARGUMENT; } @@ -7797,13 +7786,13 @@ psa_status_t psa_pake_get_implicit_key( psa_pake_operation_t *operation, psa_key_derivation_operation_t *output) { - psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE]; size_t shared_key_len = 0; psa_jpake_computation_stage_t *computation_stage = &operation->computation_stage.jpake; - if (operation->id == 0) { + if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) { return PSA_ERROR_BAD_STATE; } diff --git a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja index d52ed5993..cf08794c6 100644 --- a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja +++ b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja @@ -2816,7 +2816,7 @@ psa_status_t psa_driver_wrapper_pake_setup( psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; psa_key_location_t location = - PSA_KEY_LIFETIME_GET_LOCATION( inputs->attributes.core.lifetime ); + PSA_KEY_LIFETIME_GET_LOCATION( psa_get_key_lifetime( &inputs->attributes ) ); switch( location ) { diff --git a/tests/suites/test_suite_psa_crypto_pake.function b/tests/suites/test_suite_psa_crypto_pake.function index 5af41f75f..d77dfdc8e 100644 --- a/tests/suites/test_suite_psa_crypto_pake.function +++ b/tests/suites/test_suite_psa_crypto_pake.function @@ -590,10 +590,10 @@ void ecjpake_setup(int alg_arg, int key_type_pw_arg, int key_usage_pw_arg, TEST_EQUAL(psa_pake_set_role(&operation, role), expected_error); TEST_EQUAL(psa_pake_output(&operation, PSA_PAKE_STEP_KEY_SHARE, - NULL, 0, NULL), + output_buffer, 0, &output_len), expected_error); TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE, - NULL, 0), + output_buffer, 0), expected_error); TEST_EQUAL(psa_pake_get_implicit_key(&operation, &key_derivation), expected_error); @@ -633,7 +633,8 @@ void ecjpake_setup(int alg_arg, int key_type_pw_arg, int key_usage_pw_arg, if (test_input) { SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation, - PSA_PAKE_STEP_ZK_PROOF, NULL, 0), + PSA_PAKE_STEP_ZK_PROOF, + output_buffer, 0), ERR_INJECT_EMPTY_IO_BUFFER); SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation, @@ -665,7 +666,8 @@ void ecjpake_setup(int alg_arg, int key_type_pw_arg, int key_usage_pw_arg, } else { SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation, PSA_PAKE_STEP_ZK_PROOF, - NULL, 0, NULL), + output_buffer, 0, + &output_len), ERR_INJECT_EMPTY_IO_BUFFER); SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation,