Add session ID as explicit parameter to SSL session cache API

Signed-off-by: Hanno Becker <hanno.becker@arm.com>
This commit is contained in:
Hanno Becker 2021-04-15 09:26:17 +01:00
parent a637ff6ddd
commit ccdaf6ed22
5 changed files with 47 additions and 17 deletions

View file

@ -628,9 +628,15 @@ typedef struct mbedtls_ssl_flight_item mbedtls_ssl_flight_item;
#endif #endif
/* TODO: Document */ /* TODO: Document */
typedef int mbedtls_ssl_cache_get_t( void *data, mbedtls_ssl_session *session ); typedef int mbedtls_ssl_cache_get_t( void *data,
unsigned char const *session_id,
size_t session_id_len,
mbedtls_ssl_session *session );
/* TODO: Document */ /* TODO: Document */
typedef int mbedtls_ssl_cache_set_t( void *data, const mbedtls_ssl_session *session ); typedef int mbedtls_ssl_cache_set_t( void *data,
unsigned char const *session_id,
size_t session_id_len,
const mbedtls_ssl_session *session );
#if defined(MBEDTLS_SSL_ASYNC_PRIVATE) #if defined(MBEDTLS_SSL_ASYNC_PRIVATE)
#if defined(MBEDTLS_X509_CRT_PARSE_C) #if defined(MBEDTLS_X509_CRT_PARSE_C)

View file

@ -99,19 +99,32 @@ void mbedtls_ssl_cache_init( mbedtls_ssl_cache_context *cache );
* \brief Cache get callback implementation * \brief Cache get callback implementation
* (Thread-safe if MBEDTLS_THREADING_C is enabled) * (Thread-safe if MBEDTLS_THREADING_C is enabled)
* *
* \param data SSL cache context * \param data The SSL cache context to use.
* \param session session to retrieve entry for * \param session_id The pointer to the buffer holding the session ID
* for the session to load.
* \param session_id_len The length of \p session_id in bytes.
* \param session The address at which to store the session
* associated with \p session_id, if present.
*/ */
int mbedtls_ssl_cache_get( void *data, mbedtls_ssl_session *session ); int mbedtls_ssl_cache_get( void *data,
unsigned char const *session_id,
size_t session_id_len,
mbedtls_ssl_session *session );
/** /**
* \brief Cache set callback implementation * \brief Cache set callback implementation
* (Thread-safe if MBEDTLS_THREADING_C is enabled) * (Thread-safe if MBEDTLS_THREADING_C is enabled)
* *
* \param data SSL cache context * \param data The SSL cache context to use.
* \param session session to store entry for * \param session_id The pointer to the buffer holding the session ID
* associated to \p session.
* \param session_id_len The length of \p session_id in bytes.
* \param session The session to store.
*/ */
int mbedtls_ssl_cache_set( void *data, const mbedtls_ssl_session *session ); int mbedtls_ssl_cache_set( void *data,
unsigned char const *session_id,
size_t session_id_len,
const mbedtls_ssl_session *session );
#if defined(MBEDTLS_HAVE_TIME) #if defined(MBEDTLS_HAVE_TIME)
/** /**

View file

@ -50,7 +50,10 @@ void mbedtls_ssl_cache_init( mbedtls_ssl_cache_context *cache )
#endif #endif
} }
int mbedtls_ssl_cache_get( void *data, mbedtls_ssl_session *session ) int mbedtls_ssl_cache_get( void *data,
unsigned char const *session_id,
size_t session_id_len,
mbedtls_ssl_session *session )
{ {
int ret = 1; int ret = 1;
#if defined(MBEDTLS_HAVE_TIME) #if defined(MBEDTLS_HAVE_TIME)
@ -78,8 +81,8 @@ int mbedtls_ssl_cache_get( void *data, mbedtls_ssl_session *session )
continue; continue;
#endif #endif
if( session->id_len != entry->session.id_len || if( session_id_len != entry->session.id_len ||
memcmp( session->id, entry->session.id, memcmp( session_id, entry->session.id,
entry->session.id_len ) != 0 ) entry->session.id_len ) != 0 )
{ {
continue; continue;
@ -135,7 +138,10 @@ exit:
return( ret ); return( ret );
} }
int mbedtls_ssl_cache_set( void *data, const mbedtls_ssl_session *session ) int mbedtls_ssl_cache_set( void *data,
unsigned char const *session_id,
size_t session_id_len,
const mbedtls_ssl_session *session )
{ {
int ret = 1; int ret = 1;
#if defined(MBEDTLS_HAVE_TIME) #if defined(MBEDTLS_HAVE_TIME)
@ -167,8 +173,11 @@ int mbedtls_ssl_cache_set( void *data, const mbedtls_ssl_session *session )
} }
#endif #endif
if( memcmp( session->id, cur->session.id, cur->session.id_len ) == 0 ) if( session_id_len == cur->session.id_len &&
memcmp( session_id, cur->session.id, cur->session.id_len ) == 0 )
{
break; /* client reconnected, keep timestamp for session id */ break; /* client reconnected, keep timestamp for session id */
}
#if defined(MBEDTLS_HAVE_TIME) #if defined(MBEDTLS_HAVE_TIME)
if( oldest == 0 || cur->timestamp < oldest ) if( oldest == 0 || cur->timestamp < oldest )

View file

@ -2784,10 +2784,9 @@ static void ssl_check_id_based_session_resumption( mbedtls_ssl_context *ssl )
return; return;
#endif #endif
session_tmp.id_len = session->id_len;
memcpy( session_tmp.id, session->id, session->id_len );
ret = ssl->conf->f_get_cache( ssl->conf->p_cache, ret = ssl->conf->f_get_cache( ssl->conf->p_cache,
session->id,
session->id_len,
&session_tmp ); &session_tmp );
if( ret != 0 ) if( ret != 0 )
goto exit; goto exit;

View file

@ -3411,7 +3411,10 @@ void mbedtls_ssl_handshake_wrapup( mbedtls_ssl_context *ssl )
ssl->session->id_len != 0 && ssl->session->id_len != 0 &&
resume == 0 ) resume == 0 )
{ {
if( ssl->conf->f_set_cache( ssl->conf->p_cache, ssl->session ) != 0 ) if( ssl->conf->f_set_cache( ssl->conf->p_cache,
ssl->session->id,
ssl->session->id_len,
ssl->session ) != 0 )
MBEDTLS_SSL_DEBUG_MSG( 1, ( "cache did not store session" ) ); MBEDTLS_SSL_DEBUG_MSG( 1, ( "cache did not store session" ) );
} }