From d9a94fe3d096d488fb54688033c10fbcfd980001 Mon Sep 17 00:00:00 2001 From: Jerry Yu Date: Tue, 28 Sep 2021 18:58:59 +0800 Subject: [PATCH] Add counter length macro Signed-off-by: Jerry Yu --- include/mbedtls/ssl.h | 4 ++-- library/ssl_misc.h | 20 ++++++++++---------- library/ssl_msg.c | 34 ++++++++++++++++++---------------- library/ssl_srv.c | 3 ++- library/ssl_tls.c | 19 +++++++++++-------- 5 files changed, 43 insertions(+), 37 deletions(-) diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index 3f627139c..d2f436138 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -594,7 +594,7 @@ union mbedtls_ssl_premaster_secret #define MBEDTLS_PREMASTER_SIZE sizeof( union mbedtls_ssl_premaster_secret ) /* Length of in_ctr buffer in mbedtls_ssl_session */ -#define MBEDTLS_SSL_IN_CTR_LEN 8 +#define MBEDTLS_SSL_COUNTER_LEN 8 #ifdef __cplusplus extern "C" { @@ -1555,7 +1555,7 @@ struct mbedtls_ssl_context size_t MBEDTLS_PRIVATE(out_buf_len); /*!< length of output buffer */ #endif - unsigned char MBEDTLS_PRIVATE(cur_out_ctr)[8]; /*!< Outgoing record sequence number. */ + unsigned char MBEDTLS_PRIVATE(cur_out_ctr)[MBEDTLS_SSL_COUNTER_LEN]; /*!< Outgoing record sequence number. */ #if defined(MBEDTLS_SSL_PROTO_DTLS) uint16_t MBEDTLS_PRIVATE(mtu); /*!< path mtu, used to fragment outgoing messages */ diff --git a/library/ssl_misc.h b/library/ssl_misc.h index ea891f44a..6f83fc327 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -573,8 +573,8 @@ struct mbedtls_ssl_handshake_params flight being received */ mbedtls_ssl_transform *alt_transform_out; /*!< Alternative transform for resending messages */ - unsigned char alt_out_ctr[8]; /*!< Alternative record epoch/counter - for resending messages */ + unsigned char alt_out_ctr[MBEDTLS_SSL_COUNTER_LEN]; /*!< Alternative record epoch/counter + for resending messages */ #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) /* The state of CID configuration in this handshake. */ @@ -873,14 +873,14 @@ static inline int mbedtls_ssl_transform_uses_aead( typedef struct { - uint8_t ctr[8]; /* In TLS: The implicit record sequence number. - * In DTLS: The 2-byte epoch followed by - * the 6-byte sequence number. - * This is stored as a raw big endian byte array - * as opposed to a uint64_t because we rarely - * need to perform arithmetic on this, but do - * need it as a Byte array for the purpose of - * MAC computations. */ + uint8_t ctr[MBEDTLS_SSL_COUNTER_LEN]; /* In TLS: The implicit record sequence number. + * In DTLS: The 2-byte epoch followed by + * the 6-byte sequence number. + * This is stored as a raw big endian byte array + * as opposed to a uint64_t because we rarely + * need to perform arithmetic on this, but do + * need it as a Byte array for the purpose of + * MAC computations. */ uint8_t type; /* The record content type. */ uint8_t ver[2]; /* SSL/TLS version as present on the wire. * Convert to internal presentation of versions diff --git a/library/ssl_msg.c b/library/ssl_msg.c index 518cfeeef..25e3ca3ec 100644 --- a/library/ssl_msg.c +++ b/library/ssl_msg.c @@ -2117,9 +2117,9 @@ static int ssl_swap_epochs( mbedtls_ssl_context *ssl ) ssl->handshake->alt_transform_out = tmp_transform; /* Swap epoch + sequence_number */ - memcpy( tmp_out_ctr, ssl->cur_out_ctr, 8 ); - memcpy( ssl->cur_out_ctr, ssl->handshake->alt_out_ctr, 8 ); - memcpy( ssl->handshake->alt_out_ctr, tmp_out_ctr, 8 ); + memcpy( tmp_out_ctr, ssl->cur_out_ctr, sizeof( ssl->cur_out_ctr ) ); + memcpy( ssl->cur_out_ctr, ssl->handshake->alt_out_ctr, sizeof( ssl->cur_out_ctr ) ); + memcpy( ssl->handshake->alt_out_ctr, tmp_out_ctr, sizeof( ssl->handshake->alt_out_ctr ) ); /* Adjust to the newly activated transform */ mbedtls_ssl_update_out_pointers( ssl, ssl->transform_out ); @@ -2562,7 +2562,7 @@ int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush ) mbedtls_ssl_write_version( ssl->major_ver, ssl->minor_ver, ssl->conf->transport, ssl->out_hdr + 1 ); - memcpy( ssl->out_ctr, ssl->cur_out_ctr, 8 ); + memcpy( ssl->out_ctr, ssl->cur_out_ctr, sizeof( ssl->cur_out_ctr ) ); MBEDTLS_PUT_UINT16_BE( len, ssl->out_len, 0); if( ssl->transform_out != NULL ) @@ -2574,7 +2574,7 @@ int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush ) rec.data_len = ssl->out_msglen; rec.data_offset = ssl->out_msg - rec.buf; - memcpy( &rec.ctr[0], ssl->out_ctr, 8 ); + memcpy( &rec.ctr[0], ssl->out_ctr, MBEDTLS_SSL_COUNTER_LEN ); mbedtls_ssl_write_version( ssl->major_ver, ssl->minor_ver, ssl->conf->transport, rec.ver ); rec.type = ssl->out_msgtype; @@ -3649,7 +3649,7 @@ static int ssl_prepare_record_content( mbedtls_ssl_context *ssl, #endif { unsigned i; - for( i = MBEDTLS_SSL_IN_CTR_LEN; i > mbedtls_ssl_ep_len( ssl ); i-- ) + for( i = MBEDTLS_SSL_COUNTER_LEN; i > mbedtls_ssl_ep_len( ssl ); i-- ) if( ++ssl->in_ctr[i - 1] != 0 ) break; @@ -4791,7 +4791,7 @@ int mbedtls_ssl_parse_change_cipher_spec( mbedtls_ssl_context *ssl ) } else #endif /* MBEDTLS_SSL_PROTO_DTLS */ - mbedtls_platform_zeroize( ssl->in_ctr, MBEDTLS_SSL_IN_CTR_LEN ); + mbedtls_platform_zeroize( ssl->in_ctr, MBEDTLS_SSL_COUNTER_LEN ); mbedtls_ssl_update_in_pointers( ssl ); @@ -4827,12 +4827,12 @@ void mbedtls_ssl_update_out_pointers( mbedtls_ssl_context *ssl, { ssl->out_ctr = ssl->out_hdr + 3; #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) - ssl->out_cid = ssl->out_ctr + 8; + ssl->out_cid = ssl->out_ctr + MBEDTLS_SSL_COUNTER_LEN; ssl->out_len = ssl->out_cid; if( transform != NULL ) ssl->out_len += transform->out_cid_len; #else /* MBEDTLS_SSL_DTLS_CONNECTION_ID */ - ssl->out_len = ssl->out_ctr + 8; + ssl->out_len = ssl->out_ctr + MBEDTLS_SSL_COUNTER_LEN; #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */ ssl->out_iv = ssl->out_len + 2; } @@ -4881,17 +4881,17 @@ void mbedtls_ssl_update_in_pointers( mbedtls_ssl_context *ssl ) * ssl_parse_record_header(). */ ssl->in_ctr = ssl->in_hdr + 3; #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) - ssl->in_cid = ssl->in_ctr + MBEDTLS_SSL_IN_CTR_LEN; + ssl->in_cid = ssl->in_ctr + MBEDTLS_SSL_COUNTER_LEN; ssl->in_len = ssl->in_cid; /* Default: no CID */ #else /* MBEDTLS_SSL_DTLS_CONNECTION_ID */ - ssl->in_len = ssl->in_ctr + MBEDTLS_SSL_IN_CTR_LEN; + ssl->in_len = ssl->in_ctr + MBEDTLS_SSL_COUNTER_LEN; #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */ ssl->in_iv = ssl->in_len + 2; } else #endif { - ssl->in_ctr = ssl->in_hdr - MBEDTLS_SSL_IN_CTR_LEN; + ssl->in_ctr = ssl->in_hdr - MBEDTLS_SSL_COUNTER_LEN; ssl->in_len = ssl->in_hdr + 3; #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) ssl->in_cid = ssl->in_len; @@ -5065,9 +5065,11 @@ static int ssl_check_ctr_renegotiate( mbedtls_ssl_context *ssl ) } in_ctr_cmp = memcmp( ssl->in_ctr + ep_len, - ssl->conf->renego_period + ep_len, 8 - ep_len ); - out_ctr_cmp = memcmp( ssl->cur_out_ctr + ep_len, - ssl->conf->renego_period + ep_len, 8 - ep_len ); + &ssl->conf->renego_period[ep_len], + MBEDTLS_SSL_COUNTER_LEN - ep_len ); + out_ctr_cmp = memcmp( &ssl->cur_out_ctr[ep_len], + &ssl->conf->renego_period[ep_len], + sizeof( ssl->cur_out_ctr ) - ep_len ); if( in_ctr_cmp <= 0 && out_ctr_cmp <= 0 ) { @@ -5558,7 +5560,7 @@ void mbedtls_ssl_set_inbound_transform( mbedtls_ssl_context *ssl, return; ssl->transform_in = transform; - mbedtls_platform_zeroize( ssl->in_ctr, MBEDTLS_SSL_IN_CTR_LEN ); + mbedtls_platform_zeroize( ssl->in_ctr, MBEDTLS_SSL_COUNTER_LEN ); } void mbedtls_ssl_set_outbound_transform( mbedtls_ssl_context *ssl, diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 147bb785d..79c160ea4 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -1220,7 +1220,8 @@ read_record_header: return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER ); } - memcpy( ssl->cur_out_ctr + 2, ssl->in_ctr + 2, MBEDTLS_SSL_IN_CTR_LEN - 2 ); + memcpy( &ssl->cur_out_ctr[2], ssl->in_ctr + 2, + MBEDTLS_SSL_COUNTER_LEN - 2 ); #if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY) if( mbedtls_ssl_dtls_replay_check( ssl ) != 0 ) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index ab36f5d89..b22db47b5 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -2820,10 +2820,13 @@ int mbedtls_ssl_write_finished( mbedtls_ssl_context *ssl ) /* Remember current epoch settings for resending */ ssl->handshake->alt_transform_out = ssl->transform_out; - memcpy( ssl->handshake->alt_out_ctr, ssl->cur_out_ctr, 8 ); + memcpy( ssl->handshake->alt_out_ctr, ssl->cur_out_ctr, + sizeof( ssl->cur_out_ctr ) ); /* Set sequence_number to zero */ - memset( ssl->cur_out_ctr + 2, 0, 6 ); + mbedtls_platform_zeroize( &ssl->cur_out_ctr[2], + sizeof( ssl->cur_out_ctr ) - 2 ); + /* Increment epoch */ for( i = 2; i > 0; i-- ) @@ -2839,7 +2842,7 @@ int mbedtls_ssl_write_finished( mbedtls_ssl_context *ssl ) } else #endif /* MBEDTLS_SSL_PROTO_DTLS */ - memset( ssl->cur_out_ctr, 0, 8 ); + mbedtls_platform_zeroize( ssl->cur_out_ctr, sizeof( ssl->cur_out_ctr ) ); ssl->transform_out = ssl->transform_negotiate; ssl->session_out = ssl->session_negotiate; @@ -3324,7 +3327,7 @@ static void ssl_session_reset_msg_layer( mbedtls_ssl_context *ssl, ssl->out_msglen = 0; ssl->out_left = 0; memset( ssl->out_buf, 0, out_buf_len ); - memset( ssl->cur_out_ctr, 0, sizeof( ssl->cur_out_ctr ) ); + mbedtls_platform_zeroize( ssl->cur_out_ctr, sizeof( ssl->cur_out_ctr ) ); ssl->transform_out = NULL; #if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY) @@ -5778,7 +5781,7 @@ int mbedtls_ssl_context_save( mbedtls_ssl_context *ssl, used += 8; if( used <= buf_len ) { - memcpy( p, ssl->cur_out_ctr, 8 ); + memcpy( p, ssl->cur_out_ctr, sizeof( ssl->cur_out_ctr ) ); p += 8; } @@ -6035,11 +6038,11 @@ static int ssl_context_load( mbedtls_ssl_context *ssl, ssl->disable_datagram_packing = *p++; #endif /* MBEDTLS_SSL_PROTO_DTLS */ - if( (size_t)( end - p ) < 8 ) + if( (size_t)( end - p ) < sizeof( ssl->cur_out_ctr ) ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - memcpy( ssl->cur_out_ctr, p, 8 ); - p += 8; + memcpy( ssl->cur_out_ctr, p, sizeof( ssl->cur_out_ctr ) ); + p += sizeof( ssl->cur_out_ctr ); #if defined(MBEDTLS_SSL_PROTO_DTLS) if( (size_t)( end - p ) < 2 )