diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index e6c545e05..93ad7f351 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -1202,9 +1202,8 @@ struct mbedtls_ssl_session uint8_t MBEDTLS_PRIVATE(endpoint); /*!< 0: client, 1: server */ #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) - uint8_t MBEDTLS_PRIVATE(hostname_len); /*!< host_name length */ + size_t MBEDTLS_PRIVATE(hostname_len); /*!< host_name length */ char *MBEDTLS_PRIVATE(hostname); /*!< host name binded with tickets */ - uint8_t hostname_mismatch; /*!< whether new host_name match with saved one */ #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */ #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */ }; diff --git a/library/ssl_client.c b/library/ssl_client.c index 2c5f66494..8080e3ee7 100644 --- a/library/ssl_client.c +++ b/library/ssl_client.c @@ -54,7 +54,6 @@ static int ssl_write_hostname_ext( mbedtls_ssl_context *ssl, { unsigned char *p = buf; size_t hostname_len; - size_t cmp_hostname_len; *olen = 0; @@ -65,25 +64,8 @@ static int ssl_write_hostname_ext( mbedtls_ssl_context *ssl, ( "client hello, adding server name extension: %s", ssl->hostname ) ); - ssl->session_negotiate->hostname_mismatch = 0; hostname_len = strlen( ssl->hostname ); - cmp_hostname_len = hostname_len < ssl->session_negotiate->hostname_len ? - hostname_len : ssl->session_negotiate->hostname_len; - - if( hostname_len != ssl->session_negotiate->hostname_len || - memcmp( ssl->hostname, ssl->session_negotiate->hostname, cmp_hostname_len ) ) - ssl->session_negotiate->hostname_mismatch = 1; - - if( ssl->session_negotiate->hostname == NULL ) - { - ssl->session_negotiate->hostname = mbedtls_calloc( 1, hostname_len ); - if( ssl->session_negotiate->hostname == NULL ) - return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); - memcpy(ssl->session_negotiate->hostname, ssl->hostname, hostname_len); - } - ssl->session_negotiate->hostname_len = hostname_len; - MBEDTLS_SSL_CHK_BUF_PTR( p, end, hostname_len + 9 ); /* @@ -888,6 +870,34 @@ static int ssl_prepare_client_hello( mbedtls_ssl_context *ssl ) } } +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) && \ + defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) + if( ssl->handshake->resume ) + { + if( ssl->hostname != NULL && ssl->session_negotiate->hostname != NULL ) + { + if( strcmp( ssl->hostname, ssl->session_negotiate->hostname ) ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, + ( "hostname mismatch the session ticket, should not resume " ) ); + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } + } + else if( ssl->session_negotiate->hostname != NULL ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, + ( "hostname missed, should not resume " ) ); + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } + } + else + { + mbedtls_ssl_session_set_hostname( ssl->session_negotiate, + ssl->hostname ); + } +#endif /* MBEDTLS_SSL_PROTO_TLS1_3 && + MBEDTLS_SSL_SERVER_NAME_INDICATION */ + return( 0 ); } /* diff --git a/library/ssl_misc.h b/library/ssl_misc.h index afacb76f0..f92a4dbec 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -2201,6 +2201,10 @@ static inline int mbedtls_ssl_tls13_sig_alg_is_supported( return( 1 ); } +#if defined(MBEDTLS_X509_CRT_PARSE_C) +int mbedtls_ssl_session_set_hostname( mbedtls_ssl_session *ssl, + const char *hostname ); +#endif #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */ #if defined(MBEDTLS_SSL_PROTO_TLS1_2) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index abadc80d0..959d01540 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -297,14 +297,16 @@ int mbedtls_ssl_session_copy( mbedtls_ssl_session *dst, } #endif /* MBEDTLS_SSL_SESSION_TICKETS && MBEDTLS_SSL_CLI_C */ -#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) && defined(MBEDTLS_SSL_CLI_C) +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) && \ + defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) && \ + defined(MBEDTLS_SSL_CLI_C) if( src->endpoint == MBEDTLS_SSL_IS_CLIENT && src->hostname != NULL ) { - dst->hostname = mbedtls_calloc( 1, src->hostname_len ); + dst->hostname = mbedtls_calloc( 1, src->hostname_len + 1 ); if( dst->hostname == NULL ) return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); - memcpy( dst->hostname, src->hostname, src->hostname_len ); + strcpy( dst->hostname, src->hostname ); dst->hostname_len = src->hostname_len; } #endif @@ -1958,7 +1960,6 @@ mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_ciphersuite( * uint32 ticket_age_add; * uint8 ticket_flags; * opaque resumption_key<0..255>; - * * select ( endpoint ) { * case client: ClientOnlyData; * case server: uint64 start_time; @@ -1993,7 +1994,7 @@ static int ssl_tls13_session_save( const mbedtls_ssl_session *session, if( session->endpoint == MBEDTLS_SSL_IS_CLIENT ) { #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) - needed += 1 /* hostname_len */ + needed += 2 /* hostname_len */ + session->hostname_len; /* hostname */ #endif @@ -2026,13 +2027,15 @@ static int ssl_tls13_session_save( const mbedtls_ssl_session *session, #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) && defined(MBEDTLS_SSL_CLI_C) if( session->endpoint == MBEDTLS_SSL_IS_CLIENT ) { - p[0] = session->hostname_len; - p++; + MBEDTLS_PUT_UINT16_BE( session->hostname_len, p, 0 ); + p += 2; if ( session->hostname_len > 0 && session->hostname != NULL ) - /* save host name */ - memcpy( p, session->hostname, session->hostname_len ); - p += session->hostname_len; + { + /* save host name */ + memcpy( p, session->hostname, session->hostname_len ); + p += session->hostname_len; + } } #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION && MBEDTLS_SSL_CLI_C */ @@ -2098,19 +2101,20 @@ static int ssl_tls13_session_load( mbedtls_ssl_session *session, if( session->endpoint == MBEDTLS_SSL_IS_CLIENT ) { /* load host name */ - if( end - p < 1 ) + if( end - p < 2 ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - session->hostname_len = p[0]; - p += 1; + session->hostname_len = MBEDTLS_GET_UINT16_BE( p, 0); + p += 2; - if( end - p < session->hostname_len ) + if( end - p < ( long int )session->hostname_len ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); if( session->hostname_len > 0 ) { - session->hostname = mbedtls_calloc( 1, session->hostname_len ); + session->hostname = mbedtls_calloc( 1, session->hostname_len + 1 ); if( session->hostname == NULL ) return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); memcpy( session->hostname, p, session->hostname_len ); + session->hostname[session->hostname_len] = '\0'; p += session->hostname_len; } } @@ -3733,7 +3737,8 @@ void mbedtls_ssl_session_free( mbedtls_ssl_session *session ) mbedtls_free( session->ticket ); #endif -#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) && \ + defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) mbedtls_free( session->hostname ); #endif diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c index abb7a1481..1b827ac60 100644 --- a/library/ssl_tls13_generic.c +++ b/library/ssl_tls13_generic.c @@ -1485,4 +1485,51 @@ int mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange( } #endif /* MBEDTLS_ECDH_C */ +#if defined(MBEDTLS_X509_CRT_PARSE_C) +int mbedtls_ssl_session_set_hostname( mbedtls_ssl_session *ssl, + const char *hostname ) +{ + /* Initialize to suppress unnecessary compiler warning */ + size_t hostname_len = 0; + + /* Check if new hostname is valid before + * making any change to current one */ + if( hostname != NULL ) + { + hostname_len = strlen( hostname ); + + if( hostname_len > MBEDTLS_SSL_MAX_HOST_NAME_LEN ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + } + + /* Now it's clear that we will overwrite the old hostname, + * so we can free it safely */ + + if( ssl->hostname != NULL ) + { + mbedtls_platform_zeroize( ssl->hostname, strlen( ssl->hostname ) ); + mbedtls_free( ssl->hostname ); + } + + /* Passing NULL as hostname shall clear the old one */ + + if( hostname == NULL ) + { + ssl->hostname = NULL; + } + else + { + ssl->hostname = mbedtls_calloc( 1, hostname_len + 1 ); + if( ssl->hostname == NULL ) + return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); + + memcpy( ssl->hostname, hostname, hostname_len ); + + ssl->hostname[hostname_len] = '\0'; + ssl->hostname_len = hostname_len; + } + + return( 0 ); +} +#endif /* MBEDTLS_X509_CRT_PARSE_C */ #endif /* MBEDTLS_SSL_TLS_C && MBEDTLS_SSL_PROTO_TLS1_3 */ diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c index 9102ab40a..be474d473 100644 --- a/programs/ssl/ssl_client2.c +++ b/programs/ssl/ssl_client2.c @@ -3120,7 +3120,7 @@ reconnect: #if defined(MBEDTLS_X509_CRT_PARSE_C) if( ( ret = mbedtls_ssl_set_hostname( &ssl, - opt.reco_server_name ) ) != 0 ) + opt.reco_server_name ) ) != 0 ) { mbedtls_printf( " failed\n ! mbedtls_ssl_set_hostname returned %d\n\n", ret );