diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c index 486632ee8..1a4571cad 100644 --- a/library/ssl_tls12_server.c +++ b/library/ssl_tls12_server.c @@ -3068,7 +3068,8 @@ curve_matching_done: #if defined(MBEDTLS_USE_PSA_CRYPTO) if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA || - ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA ) + ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA || + ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK ) { psa_status_t status = PSA_ERROR_GENERIC_ERROR; psa_key_attributes_t key_attributes; @@ -4037,6 +4038,96 @@ static int ssl_parse_client_key_exchange( mbedtls_ssl_context *ssl ) } else #endif /* MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */ +#if defined(MBEDTLS_USE_PSA_CRYPTO) && \ + defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) + if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK ) + { + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t destruction_status = PSA_ERROR_CORRUPTION_DETECTED; + uint8_t ecpoint_len; + + mbedtls_ssl_handshake_params *handshake = ssl->handshake; + + if( ( ret = ssl_parse_client_psk_identity( ssl, &p, end ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, ( "ssl_parse_client_psk_identity" ), ret ); + psa_destroy_key( handshake->ecdh_psa_privkey ); + handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT; + return( ret ); + } + + /* Keep a copy of the peer's public key */ + ecpoint_len = *(p++); + if( (size_t)( end - *p ) < ecpoint_len ) { + psa_destroy_key( handshake->ecdh_psa_privkey ); + handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT; + return( MBEDTLS_ERR_SSL_DECODE_ERROR ); + } + + if( ecpoint_len > sizeof( handshake->ecdh_psa_peerkey ) ) { + psa_destroy_key( handshake->ecdh_psa_privkey ); + handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT; + return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); + } + + memcpy( handshake->ecdh_psa_peerkey, p, ecpoint_len ); + handshake->ecdh_psa_peerkey_len = ecpoint_len; + p += ecpoint_len; + + /* The ECDH secret is the premaster secret used for key derivation. */ + unsigned char *psm = ssl->handshake->premaster; + unsigned char *psm_end = psm + sizeof( ssl->handshake->premaster ); + size_t zlen; + + /* Compute ECDH shared secret. */ + status = psa_raw_key_agreement( PSA_ALG_ECDH, + handshake->ecdh_psa_privkey, + handshake->ecdh_psa_peerkey, + handshake->ecdh_psa_peerkey_len, + psm + 2, + psm_end - ( psm + 2 ), + &zlen ); + + destruction_status = psa_destroy_key( handshake->ecdh_psa_privkey ); + handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT; + + if( status != PSA_SUCCESS || destruction_status != PSA_SUCCESS ) + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + + MBEDTLS_PUT_UINT16_BE( zlen, psm, 0 ); + psm += 2 + zlen; + + /* opaque psk<0..2^16-1>; */ + if( psm_end - psm < 2 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + const unsigned char *psk = NULL; + size_t psk_len = 0; + + if( mbedtls_ssl_get_psk( ssl, &psk, &psk_len ) + == MBEDTLS_ERR_SSL_PRIVATE_KEY_REQUIRED ) + { + /* + * This should never happen because the existence of a PSK is always + * checked before calling this function + */ + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } + + MBEDTLS_PUT_UINT16_BE( psk_len, psm, 0 ); + psm += 2; + + if( psm_end < psm || (size_t)( psm_end - psm ) < psk_len ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + memcpy( psm, psk, psk_len ); + psm += psk_len; + + ssl->handshake->pmslen = psm - ssl->handshake->premaster; + } + else +#endif /* MBEDTLS_USE_PSA_CRYPTO && + MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED */ #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK ) {