diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index 49b1aca8c..51c1c60d7 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -845,7 +845,9 @@ struct mbedtls_ssl_context size_t in_hslen; /*!< current handshake message length, including the handshake header */ int nb_zero; /*!< # of 0-length encrypted messages */ - int record_read; /*!< record is already present */ + + int keep_current_message; /*!< drop or reuse current message + on next call to record layer? */ /* * Record layer (outgoing data) diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 616cadc08..a2b9f8cfe 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1471,6 +1471,8 @@ static int ssl_parse_server_hello( mbedtls_ssl_context *ssl ) } MBEDTLS_SSL_DEBUG_MSG( 1, ( "non-handshake message during renego" ) ); + + ssl->keep_current_message = 1; return( MBEDTLS_ERR_SSL_WAITING_SERVER_HELLO_RENEGO ); } #endif /* MBEDTLS_SSL_RENEGOTIATION */ @@ -2316,13 +2318,17 @@ static int ssl_parse_server_key_exchange( mbedtls_ssl_context *ssl ) if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_PSK || ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA_PSK ) { - ssl->record_read = 1; + /* Current message is probably either + * CertificateRequest or ServerHelloDone */ + ssl->keep_current_message = 1; goto exit; } - MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) ); + MBEDTLS_SSL_DEBUG_MSG( 1, ( "server key exchange message must " + "not be skipped" ) ); mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL, MBEDTLS_SSL_ALERT_MSG_UNEXPECTED_MESSAGE ); + return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE ); } @@ -2640,38 +2646,32 @@ static int ssl_parse_certificate_request( mbedtls_ssl_context *ssl ) return( 0 ); } - if( ssl->record_read == 0 ) + if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 ) { - if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret ); - return( ret ); - } - - if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE ) - { - MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad certificate request message" ) ); - mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL, - MBEDTLS_SSL_ALERT_MSG_UNEXPECTED_MESSAGE ); - return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE ); - } - - ssl->record_read = 1; + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret ); + return( ret ); } - ssl->client_auth = 0; - ssl->state++; + if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad certificate request message" ) ); + mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL, + MBEDTLS_SSL_ALERT_MSG_UNEXPECTED_MESSAGE ); + return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE ); + } - if( ssl->in_msg[0] == MBEDTLS_SSL_HS_CERTIFICATE_REQUEST ) - ssl->client_auth++; + ssl->state++; + ssl->client_auth = ( ssl->in_msg[0] == MBEDTLS_SSL_HS_CERTIFICATE_REQUEST ); MBEDTLS_SSL_DEBUG_MSG( 3, ( "got %s certificate request", ssl->client_auth ? "a" : "no" ) ); if( ssl->client_auth == 0 ) + { + /* Current message is probably the ServerHelloDone */ + ssl->keep_current_message = 1; goto exit; - - ssl->record_read = 0; + } /* * struct { @@ -2766,21 +2766,17 @@ static int ssl_parse_server_hello_done( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> parse server hello done" ) ); - if( ssl->record_read == 0 ) + if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 ) { - if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret ); - return( ret ); - } - - if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE ) - { - MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server hello done message" ) ); - return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE ); - } + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret ); + return( ret ); + } + + if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server hello done message" ) ); + return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE ); } - ssl->record_read = 0; if( ssl->in_hslen != mbedtls_ssl_hs_hdr_len( ssl ) || ssl->in_msg[0] != MBEDTLS_SSL_HS_SERVER_HELLO_DONE ) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 208500a66..0a29970e8 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3720,27 +3720,35 @@ int mbedtls_ssl_read_record( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> read record" ) ); - do { + if( ssl->keep_current_message == 0 ) + { + do { - if( ( ret = mbedtls_ssl_read_record_layer( ssl ) ) != 0 ) + if( ( ret = mbedtls_ssl_read_record_layer( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, ( "mbedtls_ssl_read_record_layer" ), ret ); + return( ret ); + } + + ret = mbedtls_ssl_handle_message_type( ssl ); + + } while( MBEDTLS_ERR_SSL_NON_FATAL == ret ); + + if( 0 != ret ) { MBEDTLS_SSL_DEBUG_RET( 1, ( "mbedtls_ssl_read_record_layer" ), ret ); return( ret ); } - ret = mbedtls_ssl_handle_message_type( ssl ); - - } while( MBEDTLS_ERR_SSL_NON_FATAL == ret ); - - if( 0 != ret ) - { - MBEDTLS_SSL_DEBUG_RET( 1, ( "mbedtls_ssl_handle_message_type" ), ret ); - return( ret ); + if( ssl->in_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE ) + { + mbedtls_ssl_update_handshake_status( ssl ); + } } - - if( ssl->in_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE ) + else { - mbedtls_ssl_update_handshake_status( ssl ); + MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= reuse previously read message" ) ); + ssl->keep_current_message = 0; } MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= read record" ) ); @@ -5603,7 +5611,8 @@ static int ssl_session_reset_int( mbedtls_ssl_context *ssl, int partial ) ssl->in_hslen = 0; ssl->nb_zero = 0; - ssl->record_read = 0; + + ssl->keep_current_message = 0; ssl->out_msg = ssl->out_buf + 13; ssl->out_msgtype = 0;