diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 6d9101147..d05579461 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -440,12 +440,12 @@ static void ssl_calc_verify_tls_sha384( const mbedtls_ssl_context *, unsigned ch static void ssl_calc_finished_tls_sha384( mbedtls_ssl_context *, unsigned char *, int ); #endif /* MBEDTLS_SHA384_C */ -static size_t ssl_session_save_tls12( const mbedtls_ssl_session *session, +static size_t ssl_tls12_session_save( const mbedtls_ssl_session *session, unsigned char *buf, size_t buf_len ); MBEDTLS_CHECK_RETURN_CRITICAL -static int ssl_session_load_tls12( mbedtls_ssl_session *session, +static int ssl_tls12_session_load( mbedtls_ssl_session *session, const unsigned char *buf, size_t len ); #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ @@ -1888,7 +1888,170 @@ mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_ciphersuite( #if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3) #if defined(MBEDTLS_SSL_PROTO_TLS1_3) -static size_t ssl_session_save_tls13( const mbedtls_ssl_session *session, +/* Serialization of TLS 1.3 sessions: + * + * struct { + * uint64 ticket_received; + * uint32 ticket_lifetime; + * opaque ticket<0..2^16>; + * } ClientOnlyData; + * + * struct { + * uint8 endpoint; + * uint8 ciphersuite[2]; + * uint32 ticket_age_add; + * uint8 ticket_flags; + * opaque resumption_key<0..255>; + * select ( endpoint ) { + * case client: ClientOnlyData; + * case server: uint64 start_time; + * }; + * } serialized_session_tls13; + * + */ +#if defined(MBEDTLS_SSL_SESSION_TICKETS) +static size_t ssl_tls13_session_save( const mbedtls_ssl_session *session, + unsigned char *buf, + size_t buf_len ) +{ + unsigned char *p = buf; + size_t needed = 1 /* endpoint */ + + 2 /* ciphersuite */ + + 4 /* ticket_age_add */ + + 2 /* key_len */ + + session->key_len; /* key */ + +#if defined(MBEDTLS_HAVE_TIME) + needed += 8; /* start_time or ticket_received */ +#endif + +#if defined(MBEDTLS_SSL_CLI_C) + if( session->endpoint == MBEDTLS_SSL_IS_CLIENT ) + { + needed += 4 /* ticket_lifetime */ + + 2 /* ticket_len */ + + session->ticket_len; /* ticket */ + } +#endif /* MBEDTLS_SSL_CLI_C */ + + if( needed > buf_len ) + return( needed ); + + p[0] = session->endpoint; + MBEDTLS_PUT_UINT16_BE( session->ciphersuite, p, 1 ); + MBEDTLS_PUT_UINT32_BE( session->ticket_age_add, p, 3 ); + p[7] = session->ticket_flags; + + /* save resumption_key */ + p[8] = session->key_len; + p += 9; + memcpy( p, session->key, session->key_len ); + p += session->key_len; + +#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C) + if( session->endpoint == MBEDTLS_SSL_IS_SERVER ) + { + MBEDTLS_PUT_UINT64_BE( (uint64_t) session->start, p, 0 ); + p += 8; + } +#endif /* MBEDTLS_HAVE_TIME */ + +#if defined(MBEDTLS_SSL_CLI_C) + if( session->endpoint == MBEDTLS_SSL_IS_CLIENT ) + { +#if defined(MBEDTLS_HAVE_TIME) + MBEDTLS_PUT_UINT64_BE( (uint64_t) session->ticket_received, p, 0 ); + p += 8; +#endif + MBEDTLS_PUT_UINT32_BE( session->ticket_lifetime, p, 0 ); + p += 4; + + MBEDTLS_PUT_UINT16_BE( session->ticket_len, p, 0 ); + p += 2; + if( session->ticket_len > 0 ) + { + memcpy( p, session->ticket, session->ticket_len ); + p += session->ticket_len; + } + } +#endif /* MBEDTLS_SSL_CLI_C */ + return( needed ); +} + +MBEDTLS_CHECK_RETURN_CRITICAL +static int ssl_tls13_session_load( mbedtls_ssl_session *session, + const unsigned char *buf, + size_t len ) +{ + const unsigned char *p = buf; + const unsigned char *end = buf + len; + + if( end - p < 9 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + session->endpoint = p[0]; + session->ciphersuite = MBEDTLS_GET_UINT16_BE( p, 1 ); + session->ticket_age_add = MBEDTLS_GET_UINT32_BE( p, 3 ); + session->ticket_flags = p[7]; + + /* load resumption_key */ + session->key_len = p[8]; + p += 9; + + if( end - p < session->key_len ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + if( sizeof( session->key ) < session->key_len) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + memcpy( session->key, p, session->key_len ); + p += session->key_len; + +#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C) + if( session->endpoint == MBEDTLS_SSL_IS_SERVER ) + { + if( end - p < 8 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + session->start = MBEDTLS_GET_UINT64_BE( p, 0 ); + p += 8; + } +#endif /* MBEDTLS_HAVE_TIME */ + +#if defined(MBEDTLS_SSL_CLI_C) + if( session->endpoint == MBEDTLS_SSL_IS_CLIENT ) + { +#if defined(MBEDTLS_HAVE_TIME) + if( end - p < 8 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + session->ticket_received = MBEDTLS_GET_UINT64_BE( p, 0 ); + p += 8; +#endif + if( end - p < 4 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + session->ticket_lifetime = MBEDTLS_GET_UINT32_BE( p, 0 ); + p += 4; + + if( end - p < 2 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + session->ticket_len = MBEDTLS_GET_UINT16_BE( p, 0 ); + p += 2; + + if( end - p < ( long int )session->ticket_len ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + if( session->ticket_len > 0 ) + { + session->ticket = mbedtls_calloc( 1, session->ticket_len ); + if( session->ticket == NULL ) + return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); + memcpy( session->ticket, p, session->ticket_len ); + p += session->ticket_len; + } + } +#endif /* MBEDTLS_SSL_CLI_C */ + + return( 0 ); + +} +#else /* MBEDTLS_SSL_SESSION_TICKETS */ +static size_t ssl_tls13_session_save( const mbedtls_ssl_session *session, unsigned char *buf, size_t buf_len ) { @@ -1897,6 +2060,17 @@ static size_t ssl_session_save_tls13( const mbedtls_ssl_session *session, ((void) buf_len); return( 0 ); } + +static int ssl_tls13_session_load( const mbedtls_ssl_session *session, + unsigned char *buf, + size_t buf_len ) +{ + ((void) session); + ((void) buf); + ((void) buf_len); + return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE ); +} +#endif /* !MBEDTLS_SSL_SESSION_TICKETS */ #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */ psa_status_t mbedtls_ssl_cipher_to_psa( mbedtls_cipher_type_t mbedtls_cipher_type, @@ -2827,12 +3001,14 @@ static int ssl_session_save( const mbedtls_ssl_session *session, size_t used = 0; size_t remaining_len; + if( session == NULL ) + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + if( !omit_header ) { /* * Add Mbed TLS version identifier */ - used += sizeof( ssl_serialized_session_header ); if( used <= buf_len ) @@ -2858,18 +3034,14 @@ static int ssl_session_save( const mbedtls_ssl_session *session, { #if defined(MBEDTLS_SSL_PROTO_TLS1_2) case MBEDTLS_SSL_VERSION_TLS1_2: - { - used += ssl_session_save_tls12( session, p, remaining_len ); + used += ssl_tls12_session_save( session, p, remaining_len ); break; - } #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ #if defined(MBEDTLS_SSL_PROTO_TLS1_3) case MBEDTLS_SSL_VERSION_TLS1_3: - { - used += ssl_session_save_tls13( session, p, remaining_len ); + used += ssl_tls13_session_save( session, p, remaining_len ); break; - } #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */ default: @@ -2908,6 +3080,11 @@ static int ssl_session_load( mbedtls_ssl_session *session, { const unsigned char *p = buf; const unsigned char * const end = buf + len; + size_t remaining_len; + + + if( session == NULL ) + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); if( !omit_header ) { @@ -2934,16 +3111,19 @@ static int ssl_session_load( mbedtls_ssl_session *session, session->tls_version = 0x0300 | *p++; /* Dispatch according to TLS version. */ + remaining_len = ( end - p ); switch( session->tls_version ) { #if defined(MBEDTLS_SSL_PROTO_TLS1_2) case MBEDTLS_SSL_VERSION_TLS1_2: - { - size_t remaining_len = ( end - p ); - return( ssl_session_load_tls12( session, p, remaining_len ) ); - } + return( ssl_tls12_session_load( session, p, remaining_len ) ); #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) + case MBEDTLS_SSL_VERSION_TLS1_3: + return( ssl_tls13_session_load( session, p, remaining_len ) ); +#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */ + default: return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } @@ -7826,7 +8006,7 @@ unsigned int mbedtls_ssl_tls12_get_preferred_hash_for_sig_alg( * } serialized_session_tls12; * */ -static size_t ssl_session_save_tls12( const mbedtls_ssl_session *session, +static size_t ssl_tls12_session_save( const mbedtls_ssl_session *session, unsigned char *buf, size_t buf_len ) { @@ -7978,7 +8158,7 @@ static size_t ssl_session_save_tls12( const mbedtls_ssl_session *session, } MBEDTLS_CHECK_RETURN_CRITICAL -static int ssl_session_load_tls12( mbedtls_ssl_session *session, +static int ssl_tls12_session_load( mbedtls_ssl_session *session, const unsigned char *buf, size_t len ) {