Use a specific function in the PSK callback
This commit is contained in:
parent
0a4fb09534
commit
4b68296626
3 changed files with 89 additions and 20 deletions
|
@ -672,6 +672,10 @@ struct mbedtls_ssl_handshake_params
|
||||||
#if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C)
|
#if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C)
|
||||||
const mbedtls_ecp_curve_info **curves; /*!< Supported elliptic curves */
|
const mbedtls_ecp_curve_info **curves; /*!< Supported elliptic curves */
|
||||||
#endif
|
#endif
|
||||||
|
#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
|
||||||
|
unsigned char *psk; /*!< PSK from the callback */
|
||||||
|
size_t psk_len; /*!< Length of PSK from callback */
|
||||||
|
#endif
|
||||||
#if defined(MBEDTLS_X509_CRT_PARSE_C)
|
#if defined(MBEDTLS_X509_CRT_PARSE_C)
|
||||||
/**
|
/**
|
||||||
* Current key/cert or key/cert list.
|
* Current key/cert or key/cert list.
|
||||||
|
@ -1581,8 +1585,10 @@ int mbedtls_ssl_set_own_cert( mbedtls_ssl_context *ssl, mbedtls_x509_crt *own_ce
|
||||||
|
|
||||||
#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
|
#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
|
||||||
/**
|
/**
|
||||||
* \brief Set the Pre Shared Key (PSK) and the identity name connected
|
* \brief Set the Pre Shared Key (PSK) and the expected identity name
|
||||||
* to it.
|
*
|
||||||
|
* \note This is mainly useful for clients. Servers will usually
|
||||||
|
* want to use \c mbedtls_ssl_set_psk_cb() instead.
|
||||||
*
|
*
|
||||||
* \param ssl SSL context
|
* \param ssl SSL context
|
||||||
* \param psk pointer to the pre-shared key
|
* \param psk pointer to the pre-shared key
|
||||||
|
@ -1592,11 +1598,28 @@ int mbedtls_ssl_set_own_cert( mbedtls_ssl_context *ssl, mbedtls_x509_crt *own_ce
|
||||||
*
|
*
|
||||||
* \return 0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED
|
* \return 0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED
|
||||||
*/
|
*/
|
||||||
int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len,
|
int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl,
|
||||||
const unsigned char *psk_identity, size_t psk_identity_len );
|
const unsigned char *psk, size_t psk_len,
|
||||||
|
const unsigned char *psk_identity, size_t psk_identity_len );
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Set the PSK callback (server-side only) (Optional).
|
* \brief Set the Pre Shared Key (PSK) for the current handshake
|
||||||
|
*
|
||||||
|
* \note This should only be called inside the PSK callback,
|
||||||
|
* ie the function passed to \c mbedtls_ssl_set_psk_cb().
|
||||||
|
*
|
||||||
|
* \param ssl SSL context
|
||||||
|
* \param psk pointer to the pre-shared key
|
||||||
|
* \param psk_len pre-shared key length
|
||||||
|
*
|
||||||
|
* \return 0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED
|
||||||
|
*/
|
||||||
|
int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
|
||||||
|
const unsigned char *psk, size_t psk_len );
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Set the PSK callback (server-side only).
|
||||||
*
|
*
|
||||||
* If set, the PSK callback is called for each
|
* If set, the PSK callback is called for each
|
||||||
* handshake where a PSK ciphersuite was negotiated.
|
* handshake where a PSK ciphersuite was negotiated.
|
||||||
|
@ -1607,10 +1630,14 @@ int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, siz
|
||||||
* mbedtls_ssl_context *ssl, const unsigned char *psk_identity,
|
* mbedtls_ssl_context *ssl, const unsigned char *psk_identity,
|
||||||
* size_t identity_len)
|
* size_t identity_len)
|
||||||
* If a valid PSK identity is found, the callback should use
|
* If a valid PSK identity is found, the callback should use
|
||||||
* mbedtls_ssl_set_psk() on the ssl context to set the correct PSK and
|
* \c mbedtls_ssl_set_hs_psk() on the ssl context to set the
|
||||||
* identity and return 0.
|
* correct PSK and return 0.
|
||||||
* Any other return value will result in a denied PSK identity.
|
* Any other return value will result in a denied PSK identity.
|
||||||
*
|
*
|
||||||
|
* \note If you set a PSK callback using this function, then you
|
||||||
|
* don't need to set a PSK key and identity using
|
||||||
|
* \c mbedtls_ssl_set_psk().
|
||||||
|
*
|
||||||
* \param conf SSL configuration
|
* \param conf SSL configuration
|
||||||
* \param f_psk PSK identity function
|
* \param f_psk PSK identity function
|
||||||
* \param p_psk PSK identity parameter
|
* \param p_psk PSK identity parameter
|
||||||
|
|
|
@ -1066,6 +1066,15 @@ int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exch
|
||||||
{
|
{
|
||||||
unsigned char *p = ssl->handshake->premaster;
|
unsigned char *p = ssl->handshake->premaster;
|
||||||
unsigned char *end = p + sizeof( ssl->handshake->premaster );
|
unsigned char *end = p + sizeof( ssl->handshake->premaster );
|
||||||
|
const unsigned char *psk = ssl->conf->psk;
|
||||||
|
size_t psk_len = ssl->conf->psk_len;
|
||||||
|
|
||||||
|
/* If the psk callback was called, use its result */
|
||||||
|
if( ssl->handshake->psk != NULL )
|
||||||
|
{
|
||||||
|
psk = ssl->handshake->psk;
|
||||||
|
psk_len = ssl->handshake->psk_len;
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* PMS = struct {
|
* PMS = struct {
|
||||||
|
@ -1077,12 +1086,12 @@ int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exch
|
||||||
#if defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
|
#if defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
|
||||||
if( key_ex == MBEDTLS_KEY_EXCHANGE_PSK )
|
if( key_ex == MBEDTLS_KEY_EXCHANGE_PSK )
|
||||||
{
|
{
|
||||||
if( end - p < 2 + (int) ssl->conf->psk_len )
|
if( end - p < 2 + (int) psk_len )
|
||||||
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
||||||
|
|
||||||
*(p++) = (unsigned char)( ssl->conf->psk_len >> 8 );
|
*(p++) = (unsigned char)( psk_len >> 8 );
|
||||||
*(p++) = (unsigned char)( ssl->conf->psk_len );
|
*(p++) = (unsigned char)( psk_len );
|
||||||
p += ssl->conf->psk_len;
|
p += psk_len;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
#endif /* MBEDTLS_KEY_EXCHANGE_PSK_ENABLED */
|
#endif /* MBEDTLS_KEY_EXCHANGE_PSK_ENABLED */
|
||||||
|
@ -1149,13 +1158,13 @@ int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exch
|
||||||
}
|
}
|
||||||
|
|
||||||
/* opaque psk<0..2^16-1>; */
|
/* opaque psk<0..2^16-1>; */
|
||||||
if( end - p < 2 + (int) ssl->conf->psk_len )
|
if( end - p < 2 + (int) psk_len )
|
||||||
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
||||||
|
|
||||||
*(p++) = (unsigned char)( ssl->conf->psk_len >> 8 );
|
*(p++) = (unsigned char)( psk_len >> 8 );
|
||||||
*(p++) = (unsigned char)( ssl->conf->psk_len );
|
*(p++) = (unsigned char)( psk_len );
|
||||||
memcpy( p, ssl->conf->psk, ssl->conf->psk_len );
|
memcpy( p, psk, psk_len );
|
||||||
p += ssl->conf->psk_len;
|
p += psk_len;
|
||||||
|
|
||||||
ssl->handshake->pmslen = p - ssl->handshake->premaster;
|
ssl->handshake->pmslen = p - ssl->handshake->premaster;
|
||||||
|
|
||||||
|
@ -5353,8 +5362,9 @@ int mbedtls_ssl_set_own_cert( mbedtls_ssl_context *ssl, mbedtls_x509_crt *own_ce
|
||||||
#endif /* MBEDTLS_X509_CRT_PARSE_C */
|
#endif /* MBEDTLS_X509_CRT_PARSE_C */
|
||||||
|
|
||||||
#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
|
#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
|
||||||
int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len,
|
int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl,
|
||||||
const unsigned char *psk_identity, size_t psk_identity_len )
|
const unsigned char *psk, size_t psk_len,
|
||||||
|
const unsigned char *psk_identity, size_t psk_identity_len )
|
||||||
{
|
{
|
||||||
if( psk == NULL || psk_identity == NULL )
|
if( psk == NULL || psk_identity == NULL )
|
||||||
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
||||||
|
@ -5385,6 +5395,31 @@ int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, siz
|
||||||
return( 0 );
|
return( 0 );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
|
||||||
|
const unsigned char *psk, size_t psk_len )
|
||||||
|
{
|
||||||
|
if( psk == NULL || ssl->handshake == NULL )
|
||||||
|
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
||||||
|
|
||||||
|
if( psk_len > MBEDTLS_PSK_MAX_LEN )
|
||||||
|
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
|
||||||
|
|
||||||
|
if( ssl->handshake->psk != NULL )
|
||||||
|
mbedtls_free( ssl->conf->psk );
|
||||||
|
|
||||||
|
if( ( ssl->handshake->psk = mbedtls_malloc( psk_len ) ) == NULL )
|
||||||
|
{
|
||||||
|
mbedtls_free( ssl->handshake->psk );
|
||||||
|
ssl->handshake->psk = NULL;
|
||||||
|
return( MBEDTLS_ERR_SSL_MALLOC_FAILED );
|
||||||
|
}
|
||||||
|
|
||||||
|
ssl->handshake->psk_len = psk_len;
|
||||||
|
memcpy( ssl->handshake->psk, psk, ssl->handshake->psk_len );
|
||||||
|
|
||||||
|
return( 0 );
|
||||||
|
}
|
||||||
|
|
||||||
void mbedtls_ssl_set_psk_cb( mbedtls_ssl_config *conf,
|
void mbedtls_ssl_set_psk_cb( mbedtls_ssl_config *conf,
|
||||||
int (*f_psk)(void *, mbedtls_ssl_context *, const unsigned char *,
|
int (*f_psk)(void *, mbedtls_ssl_context *, const unsigned char *,
|
||||||
size_t),
|
size_t),
|
||||||
|
@ -6441,6 +6476,14 @@ void mbedtls_ssl_handshake_free( mbedtls_ssl_handshake_params *handshake )
|
||||||
mbedtls_free( (void *) handshake->curves );
|
mbedtls_free( (void *) handshake->curves );
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
|
||||||
|
if( handshake->psk != NULL )
|
||||||
|
{
|
||||||
|
mbedtls_zeroize( handshake->psk, handshake->psk_len );
|
||||||
|
mbedtls_free( handshake->psk );
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
|
#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
|
||||||
defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
|
defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -678,8 +678,7 @@ int psk_callback( void *p_info, mbedtls_ssl_context *ssl,
|
||||||
if( name_len == strlen( cur->name ) &&
|
if( name_len == strlen( cur->name ) &&
|
||||||
memcmp( name, cur->name, name_len ) == 0 )
|
memcmp( name, cur->name, name_len ) == 0 )
|
||||||
{
|
{
|
||||||
return( mbedtls_ssl_set_psk( ssl, cur->key, cur->key_len,
|
return( mbedtls_ssl_set_hs_psk( ssl, cur->key, cur->key_len ) );
|
||||||
name, name_len ) );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cur = cur->next;
|
cur = cur->next;
|
||||||
|
|
Loading…
Reference in a new issue