diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c index c757c6cba..bca7f9e29 100644 --- a/library/ssl_tls13_client.c +++ b/library/ssl_tls13_client.c @@ -838,36 +838,20 @@ static int ssl_tls13_write_binder( mbedtls_ssl_context *ssl, unsigned char *buf, unsigned char *end, int psk_type, + psa_algorithm_t hash_alg, + const unsigned char *psk, + size_t psk_len, size_t *out_len ) { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; unsigned char *p = buf; - const unsigned char *psk; - psa_algorithm_t psa_alg = PSA_ALG_NONE; - size_t psk_len; unsigned char binder_len; unsigned char transcript[MBEDTLS_MD_MAX_SIZE]; size_t transcript_len = 0; *out_len = 0; - switch( psk_type ) - { -#if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_HAVE_TIME) - case MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION: - if( ssl_tls13_ticket_get_psk( ssl, &psa_alg, &psk, &psk_len ) != 0 ) - return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); - break; -#endif /* MBEDTLS_SSL_SESSION_TICKETS && MBEDTLS_HAVE_TIME*/ - case MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL: - if( ssl_tls13_psk_get_psk( ssl, &psa_alg, &psk, &psk_len ) != 0 ) - return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); - break; - default: - return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); - } - - binder_len = PSA_HASH_LENGTH( psa_alg ); + binder_len = PSA_HASH_LENGTH( hash_alg ); if( binder_len == 0 ) { MBEDTLS_SSL_DEBUG_MSG( 3, ( "should never happen" ) ); @@ -884,14 +868,12 @@ static int ssl_tls13_write_binder( mbedtls_ssl_context *ssl, /* Get current state of handshake transcript. */ ret = mbedtls_ssl_get_handshake_transcript( - ssl, mbedtls_hash_info_md_from_psa( psa_alg ), + ssl, mbedtls_hash_info_md_from_psa( hash_alg ), transcript, MBEDTLS_MD_MAX_SIZE, &transcript_len ); if( ret != 0 ) return( ret ); - - - ret = mbedtls_ssl_tls13_create_psk_binder( ssl, psa_alg, + ret = mbedtls_ssl_tls13_create_psk_binder( ssl, hash_alg, psk, psk_len, psk_type, transcript, p + 1 ); if( ret != 0 ) @@ -1043,6 +1025,10 @@ int mbedtls_ssl_tls13_write_binders_of_pre_shared_key_ext( { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; unsigned char *p = buf; + psa_algorithm_t hash_alg = PSA_ALG_NONE; + const unsigned char *psk; + size_t psk_len; + size_t output_len; /* Check if we have space to write binders_len. * - binders_len (2 bytes) @@ -1050,22 +1036,26 @@ int mbedtls_ssl_tls13_write_binders_of_pre_shared_key_ext( MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 ); p += 2; - if( ssl_tls13_has_configured_ticket( ssl ) ) +#if defined(MBEDTLS_SSL_SESSION_TICKETS) + if( ssl_tls13_ticket_get_psk( ssl, &hash_alg, &psk, &psk_len ) == 0 ) { - size_t output_len; + ret = ssl_tls13_write_binder( ssl, p, end, MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION, + hash_alg, psk, psk_len, &output_len ); if( ret != 0 ) return( ret ); p += output_len; } +#endif /* MBEDTLS_SSL_SESSION_TICKETS */ - if( ssl_tls13_has_configured_psk( ssl ) ) + if( ssl_tls13_psk_get_psk( ssl, &hash_alg, &psk, &psk_len ) == 0 ) { - size_t output_len; + ret = ssl_tls13_write_binder( ssl, p, end, MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL, + hash_alg, psk, psk_len, &output_len ); if( ret != 0 ) return( ret );