diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 2e1edea93..dfcc085a0 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -6911,9 +6911,9 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, size_t iv_copy_len; size_t keylen; const mbedtls_ssl_ciphersuite_t *ciphersuite_info; - const mbedtls_cipher_info_t *cipher_info; mbedtls_ssl_mode_t ssl_mode; #if !defined(MBEDTLS_USE_PSA_CRYPTO) + const mbedtls_cipher_info_t *cipher_info; const mbedtls_md_info_t *md_info; #endif /* !MBEDTLS_USE_PSA_CRYPTO */ @@ -6974,6 +6974,22 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, #endif /* MBEDTLS_SSL_ENCRYPT_THEN_MAC */ ciphersuite_info ); + if( ssl_mode == MBEDTLS_SSL_MODE_AEAD ) + transform->taglen = + ciphersuite_info->flags & MBEDTLS_CIPHERSUITE_SHORT_TAG ? 8 : 16; + +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ( status = mbedtls_ssl_cipher_to_psa( ciphersuite_info->cipher, + transform->taglen, + &alg, + &key_type, + &key_bits ) ) != PSA_SUCCESS ) + { + ret = psa_ssl_status_to_mbedtls( status ); + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_cipher_to_psa", ret ); + goto end; + } +#else cipher_info = mbedtls_cipher_info_from_type( ciphersuite_info->cipher ); if( cipher_info == NULL ) { @@ -6981,6 +6997,7 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, ciphersuite_info->cipher ) ); return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_USE_PSA_CRYPTO) mac_alg = mbedtls_psa_translate_md( ciphersuite_info->mac ); @@ -7040,7 +7057,11 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, * Determine the appropriate key, IV and MAC length. */ +#if defined(MBEDTLS_USE_PSA_CRYPTO) + keylen = PSA_BITS_TO_BYTES(key_bits); +#else keylen = mbedtls_cipher_info_get_key_bitlen( cipher_info ) / 8; +#endif #if defined(MBEDTLS_GCM_C) || \ defined(MBEDTLS_CCM_C) || \ @@ -7051,8 +7072,6 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, transform->maclen = 0; mac_key_len = 0; - transform->taglen = - ciphersuite_info->flags & MBEDTLS_CIPHERSUITE_SHORT_TAG ? 8 : 16; /* All modes haves 96-bit IVs, but the length of the static parts vary * with mode and version: @@ -7063,7 +7082,11 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, * sequence number). */ transform->ivlen = 12; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( key_type == PSA_KEY_TYPE_CHACHA20 ) +#else if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CHACHAPOLY ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ transform->fixed_ivlen = 12; else transform->fixed_ivlen = 4; @@ -7079,6 +7102,12 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, ssl_mode == MBEDTLS_SSL_MODE_CBC || ssl_mode == MBEDTLS_SSL_MODE_CBC_ETM ) { +#if defined(MBEDTLS_USE_PSA_CRYPTO) + size_t block_size = PSA_BLOCK_CIPHER_BLOCK_MAX_SIZE; +#else + size_t block_size = cipher_info->block_size; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ + #if defined(MBEDTLS_USE_PSA_CRYPTO) /* Get MAC length */ mac_key_len = PSA_HASH_LENGTH(mac_alg); @@ -7097,7 +7126,14 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, transform->maclen = mac_key_len; /* IV length */ +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ssl_mode == MBEDTLS_SSL_MODE_STREAM ) + transform->ivlen = 0; + else + transform->ivlen = PSA_CIPHER_IV_LENGTH( key_type, alg ); +#else transform->ivlen = cipher_info->iv_size; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ /* Minimum length */ if( ssl_mode == MBEDTLS_SSL_MODE_STREAM ) @@ -7114,14 +7150,14 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, if( ssl_mode == MBEDTLS_SSL_MODE_CBC_ETM ) { transform->minlen = transform->maclen - + cipher_info->block_size; + + block_size; } else #endif { transform->minlen = transform->maclen - + cipher_info->block_size - - transform->maclen % cipher_info->block_size; + + block_size + - transform->maclen % block_size; } if( tls_version == MBEDTLS_SSL_VERSION_TLS1_2 ) @@ -7203,17 +7239,6 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, } #if defined(MBEDTLS_USE_PSA_CRYPTO) - if( ( status = mbedtls_ssl_cipher_to_psa( cipher_info->type, - transform->taglen, - &alg, - &key_type, - &key_bits ) ) != PSA_SUCCESS ) - { - ret = psa_ssl_status_to_mbedtls( status ); - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_cipher_to_psa", ret ); - goto end; - } - transform->psa_alg = alg; if ( alg != MBEDTLS_SSL_NULL_CIPHER )