Add explicit counter in DTLS record header
This commit is contained in:
parent
507e1e410a
commit
0619348288
3 changed files with 86 additions and 52 deletions
|
@ -1827,6 +1827,15 @@ void ssl_write_version( int major, int minor, int transport,
|
|||
void ssl_read_version( int *major, int *minor, int transport,
|
||||
const unsigned char ver[2] );
|
||||
|
||||
static inline size_t ssl_hdr_len( const ssl_context *ssl )
|
||||
{
|
||||
#if defined(POLARSSL_SSL_PROTO_DTLS)
|
||||
if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
|
||||
return( 13 );
|
||||
#endif
|
||||
return( 5 );
|
||||
}
|
||||
|
||||
/* constant-time buffer comparison */
|
||||
static inline int safer_memcmp( const void *a, const void *b, size_t n )
|
||||
{
|
||||
|
|
|
@ -1134,7 +1134,7 @@ static int ssl_parse_client_hello( ssl_context *ssl )
|
|||
SSL_DEBUG_MSG( 2, ( "=> parse client hello" ) );
|
||||
|
||||
if( ssl->renegotiation == SSL_INITIAL_HANDSHAKE &&
|
||||
( ret = ssl_fetch_input( ssl, 5 ) ) != 0 )
|
||||
( ret = ssl_fetch_input( ssl, ssl_hdr_len( ssl ) ) ) != 0 )
|
||||
{
|
||||
SSL_DEBUG_RET( 1, "ssl_fetch_input", ret );
|
||||
return( ret );
|
||||
|
@ -1147,7 +1147,7 @@ static int ssl_parse_client_hello( ssl_context *ssl )
|
|||
return ssl_parse_client_hello_v2( ssl );
|
||||
#endif
|
||||
|
||||
SSL_DEBUG_BUF( 4, "record header", buf, 5 ); // TODO: 13 for DTLS
|
||||
SSL_DEBUG_BUF( 4, "record header", buf, ssl_hdr_len( ssl ) );
|
||||
|
||||
SSL_DEBUG_MSG( 3, ( "client hello v3, message type: %d",
|
||||
buf[0] ) );
|
||||
|
@ -1191,7 +1191,7 @@ static int ssl_parse_client_hello( ssl_context *ssl )
|
|||
}
|
||||
|
||||
if( ssl->renegotiation == SSL_INITIAL_HANDSHAKE &&
|
||||
( ret = ssl_fetch_input( ssl, 5 + n ) ) != 0 )
|
||||
( ret = ssl_fetch_input( ssl, ssl_hdr_len( ssl ) + n ) ) != 0 )
|
||||
{
|
||||
SSL_DEBUG_RET( 1, "ssl_fetch_input", ret );
|
||||
return( ret );
|
||||
|
@ -1199,7 +1199,7 @@ static int ssl_parse_client_hello( ssl_context *ssl )
|
|||
|
||||
buf = ssl->in_msg;
|
||||
if( !ssl->renegotiation )
|
||||
n = ssl->in_left - 5;
|
||||
n = ssl->in_left - ssl_hdr_len( ssl );
|
||||
else
|
||||
n = ssl->in_msglen;
|
||||
|
||||
|
|
|
@ -1284,18 +1284,6 @@ static int ssl_encrypt_buf( ssl_context *ssl )
|
|||
return( POLARSSL_ERR_SSL_INTERNAL_ERROR );
|
||||
}
|
||||
|
||||
// TODO: adapt for DTLS (start from i = 6)
|
||||
for( i = 8; i > 0; i-- )
|
||||
if( ++ssl->out_ctr[i - 1] != 0 )
|
||||
break;
|
||||
|
||||
/* The loops goes to its end iff the counter is wrapping */
|
||||
if( i == 0 )
|
||||
{
|
||||
SSL_DEBUG_MSG( 1, ( "outgoing message counter would wrap" ) );
|
||||
return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
|
||||
}
|
||||
|
||||
SSL_DEBUG_MSG( 2, ( "<= encrypt buf" ) );
|
||||
|
||||
return( 0 );
|
||||
|
@ -1702,16 +1690,19 @@ static int ssl_decrypt_buf( ssl_context *ssl )
|
|||
else
|
||||
ssl->nb_zero = 0;
|
||||
|
||||
// TODO: DTLS: i = 6
|
||||
for( i = 8; i > 0; i-- )
|
||||
if( ++ssl->in_ctr[i - 1] != 0 )
|
||||
break;
|
||||
|
||||
/* The loops goes to its end iff the counter is wrapping */
|
||||
if( i == 0 )
|
||||
/* For DTLS we don't maintain our own incoming counter (for now) */
|
||||
if( ssl->transport == SSL_TRANSPORT_STREAM )
|
||||
{
|
||||
SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) );
|
||||
return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
|
||||
for( i = 8; i > 0; i-- )
|
||||
if( ++ssl->in_ctr[i - 1] != 0 )
|
||||
break;
|
||||
|
||||
/* The loop goes to its end iff the counter is wrapping */
|
||||
if( i == 0 )
|
||||
{
|
||||
SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) );
|
||||
return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
|
||||
}
|
||||
}
|
||||
|
||||
SSL_DEBUG_MSG( 2, ( "<= decrypt buf" ) );
|
||||
|
@ -1860,17 +1851,25 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want )
|
|||
*/
|
||||
int ssl_flush_output( ssl_context *ssl )
|
||||
{
|
||||
int ret;
|
||||
int ret, i;
|
||||
unsigned char *buf;
|
||||
|
||||
SSL_DEBUG_MSG( 2, ( "=> flush output" ) );
|
||||
|
||||
/* Avoid incrementing counter if data is flushed */
|
||||
if( ssl->out_left == 0 )
|
||||
{
|
||||
SSL_DEBUG_MSG( 2, ( "<= flush output" ) );
|
||||
return( 0 );
|
||||
}
|
||||
|
||||
while( ssl->out_left > 0 )
|
||||
{
|
||||
SSL_DEBUG_MSG( 2, ( "message length: %d, out_left: %d",
|
||||
5 + ssl->out_msglen, ssl->out_left ) );
|
||||
ssl_hdr_len( ssl ) + ssl->out_msglen, ssl->out_left ) );
|
||||
|
||||
buf = ssl->out_hdr + 5 + ssl->out_msglen - ssl->out_left;
|
||||
buf = ssl->out_hdr + ssl_hdr_len( ssl ) +
|
||||
ssl->out_msglen - ssl->out_left;
|
||||
ret = ssl->f_send( ssl->p_send, buf, ssl->out_left );
|
||||
|
||||
SSL_DEBUG_RET( 2, "ssl->f_send", ret );
|
||||
|
@ -1881,6 +1880,18 @@ int ssl_flush_output( ssl_context *ssl )
|
|||
ssl->out_left -= ret;
|
||||
}
|
||||
|
||||
// TODO: adapt for DTLS (start from i = 6)
|
||||
for( i = 8; i > 0; i-- )
|
||||
if( ++ssl->out_ctr[i - 1] != 0 )
|
||||
break;
|
||||
|
||||
/* The loop goes to its end iff the counter is wrapping */
|
||||
if( i == 0 )
|
||||
{
|
||||
SSL_DEBUG_MSG( 1, ( "outgoing message counter would wrap" ) );
|
||||
return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
|
||||
}
|
||||
|
||||
SSL_DEBUG_MSG( 2, ( "<= flush output" ) );
|
||||
|
||||
return( 0 );
|
||||
|
@ -1958,7 +1969,7 @@ int ssl_write_record( ssl_context *ssl )
|
|||
ssl->out_len[1] = (unsigned char)( len );
|
||||
}
|
||||
|
||||
ssl->out_left = 5 + ssl->out_msglen;
|
||||
ssl->out_left = ssl_hdr_len( ssl ) + ssl->out_msglen;
|
||||
|
||||
SSL_DEBUG_MSG( 3, ( "output record: msgtype = %d, "
|
||||
"version = [%d:%d], msglen = %d",
|
||||
|
@ -1966,7 +1977,7 @@ int ssl_write_record( ssl_context *ssl )
|
|||
( ssl->out_len[0] << 8 ) | ssl->out_len[1] ) );
|
||||
|
||||
SSL_DEBUG_BUF( 4, "output record sent to network",
|
||||
ssl->out_hdr, 5 + ssl->out_msglen );
|
||||
ssl->out_hdr, ssl_hdr_len( ssl ) + ssl->out_msglen );
|
||||
}
|
||||
|
||||
if( ( ret = ssl_flush_output( ssl ) ) != 0 )
|
||||
|
@ -2028,7 +2039,7 @@ int ssl_read_record( ssl_context *ssl )
|
|||
/*
|
||||
* Read the record header and validate it
|
||||
*/
|
||||
if( ( ret = ssl_fetch_input( ssl, 5 ) ) != 0 )
|
||||
if( ( ret = ssl_fetch_input( ssl, ssl_hdr_len( ssl ) ) ) != 0 )
|
||||
{
|
||||
SSL_DEBUG_RET( 1, "ssl_fetch_input", ret );
|
||||
return( ret );
|
||||
|
@ -2110,14 +2121,15 @@ int ssl_read_record( ssl_context *ssl )
|
|||
/*
|
||||
* Read and optionally decrypt the message contents
|
||||
*/
|
||||
if( ( ret = ssl_fetch_input( ssl, 5 + ssl->in_msglen ) ) != 0 )
|
||||
if( ( ret = ssl_fetch_input( ssl,
|
||||
ssl_hdr_len( ssl ) + ssl->in_msglen ) ) != 0 )
|
||||
{
|
||||
SSL_DEBUG_RET( 1, "ssl_fetch_input", ret );
|
||||
return( ret );
|
||||
}
|
||||
|
||||
SSL_DEBUG_BUF( 4, "input record from network",
|
||||
ssl->in_hdr, 5 + ssl->in_msglen );
|
||||
ssl->in_hdr, ssl_hdr_len( ssl ) + ssl->in_msglen );
|
||||
|
||||
#if defined(POLARSSL_SSL_HW_RECORD_ACCEL)
|
||||
if( ssl_hw_record_read != NULL )
|
||||
|
@ -3417,39 +3429,27 @@ int ssl_init( ssl_context *ssl )
|
|||
#endif
|
||||
|
||||
/*
|
||||
* Prepare base structures (assume TLS for now)
|
||||
* Prepare base structures
|
||||
*/
|
||||
ssl->in_buf = (unsigned char *) polarssl_malloc( len );
|
||||
ssl->in_ctr = ssl->in_buf;
|
||||
ssl->in_hdr = ssl->in_buf + 8;
|
||||
ssl->in_len = ssl->in_buf + 11;
|
||||
ssl->in_iv = ssl->in_buf + 13;
|
||||
ssl->in_msg = ssl->in_buf + 13;
|
||||
|
||||
if( ssl->in_buf == NULL )
|
||||
{
|
||||
SSL_DEBUG_MSG( 1, ( "malloc(%d bytes) failed", len ) );
|
||||
return( POLARSSL_ERR_SSL_MALLOC_FAILED );
|
||||
}
|
||||
|
||||
ssl->out_buf = (unsigned char *) polarssl_malloc( len );
|
||||
ssl->out_ctr = ssl->out_buf;
|
||||
ssl->out_hdr = ssl->out_buf + 8;
|
||||
ssl->out_len = ssl->out_buf + 11;
|
||||
ssl->out_iv = ssl->out_buf + 13;
|
||||
ssl->out_msg = ssl->out_buf + 13;
|
||||
|
||||
if( ssl->out_buf == NULL )
|
||||
if( ssl->in_buf == NULL || ssl->out_buf == NULL )
|
||||
{
|
||||
SSL_DEBUG_MSG( 1, ( "malloc(%d bytes) failed", len ) );
|
||||
polarssl_free( ssl->in_buf );
|
||||
polarssl_free( ssl->out_buf );
|
||||
ssl->in_buf = NULL;
|
||||
ssl->out_buf = NULL;
|
||||
return( POLARSSL_ERR_SSL_MALLOC_FAILED );
|
||||
}
|
||||
|
||||
memset( ssl-> in_buf, 0, SSL_BUFFER_LEN );
|
||||
memset( ssl->out_buf, 0, SSL_BUFFER_LEN );
|
||||
|
||||
/* No error is possible, SSL_TRANSPORT_STREAM always valid */
|
||||
(void) ssl_set_transport( ssl, SSL_TRANSPORT_STREAM );
|
||||
|
||||
#if defined(POLARSSL_SSL_SESSION_TICKETS)
|
||||
ssl->ticket_lifetime = SSL_DEFAULT_TICKET_LIFETIME;
|
||||
#endif
|
||||
|
@ -3617,6 +3617,18 @@ int ssl_set_transport( ssl_context *ssl, int transport )
|
|||
{
|
||||
ssl->transport = transport;
|
||||
|
||||
ssl->out_hdr = ssl->out_buf;
|
||||
ssl->out_ctr = ssl->out_buf + 3;
|
||||
ssl->out_len = ssl->out_buf + 11;
|
||||
ssl->out_iv = ssl->out_buf + 13;
|
||||
ssl->out_msg = ssl->out_buf + 13;
|
||||
|
||||
ssl->in_hdr = ssl->in_buf;
|
||||
ssl->in_ctr = ssl->in_buf + 3;
|
||||
ssl->in_len = ssl->in_buf + 11;
|
||||
ssl->in_iv = ssl->in_buf + 13;
|
||||
ssl->in_msg = ssl->in_buf + 13;
|
||||
|
||||
/* DTLS starts with TLS1.1 */
|
||||
if( ssl->min_minor_ver < SSL_MINOR_VERSION_2 )
|
||||
ssl->min_minor_ver = SSL_MINOR_VERSION_2;
|
||||
|
@ -3631,6 +3643,19 @@ int ssl_set_transport( ssl_context *ssl, int transport )
|
|||
if( transport == SSL_TRANSPORT_STREAM )
|
||||
{
|
||||
ssl->transport = transport;
|
||||
|
||||
ssl->out_ctr = ssl->out_buf;
|
||||
ssl->out_hdr = ssl->out_buf + 8;
|
||||
ssl->out_len = ssl->out_buf + 11;
|
||||
ssl->out_iv = ssl->out_buf + 13;
|
||||
ssl->out_msg = ssl->out_buf + 13;
|
||||
|
||||
ssl->in_ctr = ssl->in_buf;
|
||||
ssl->in_hdr = ssl->in_buf + 8;
|
||||
ssl->in_len = ssl->in_buf + 11;
|
||||
ssl->in_iv = ssl->in_buf + 13;
|
||||
ssl->in_msg = ssl->in_buf + 13;
|
||||
|
||||
return( 0 );
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue