diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index a72914fb3..e098bc9b5 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -672,6 +672,10 @@ struct mbedtls_ssl_handshake_params #if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C) const mbedtls_ecp_curve_info **curves; /*!< Supported elliptic curves */ #endif +#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) + unsigned char *psk; /*!< PSK from the callback */ + size_t psk_len; /*!< Length of PSK from callback */ +#endif #if defined(MBEDTLS_X509_CRT_PARSE_C) /** * Current key/cert or key/cert list. @@ -1581,8 +1585,10 @@ int mbedtls_ssl_set_own_cert( mbedtls_ssl_context *ssl, mbedtls_x509_crt *own_ce #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) /** - * \brief Set the Pre Shared Key (PSK) and the identity name connected - * to it. + * \brief Set the Pre Shared Key (PSK) and the expected identity name + * + * \note This is mainly useful for clients. Servers will usually + * want to use \c mbedtls_ssl_set_psk_cb() instead. * * \param ssl SSL context * \param psk pointer to the pre-shared key @@ -1592,11 +1598,28 @@ int mbedtls_ssl_set_own_cert( mbedtls_ssl_context *ssl, mbedtls_x509_crt *own_ce * * \return 0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED */ -int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len, - const unsigned char *psk_identity, size_t psk_identity_len ); +int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, + const unsigned char *psk, size_t psk_len, + const unsigned char *psk_identity, size_t psk_identity_len ); + /** - * \brief Set the PSK callback (server-side only) (Optional). + * \brief Set the Pre Shared Key (PSK) for the current handshake + * + * \note This should only be called inside the PSK callback, + * ie the function passed to \c mbedtls_ssl_set_psk_cb(). + * + * \param ssl SSL context + * \param psk pointer to the pre-shared key + * \param psk_len pre-shared key length + * + * \return 0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED + */ +int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, + const unsigned char *psk, size_t psk_len ); + +/** + * \brief Set the PSK callback (server-side only). * * If set, the PSK callback is called for each * handshake where a PSK ciphersuite was negotiated. @@ -1607,10 +1630,14 @@ int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, siz * mbedtls_ssl_context *ssl, const unsigned char *psk_identity, * size_t identity_len) * If a valid PSK identity is found, the callback should use - * mbedtls_ssl_set_psk() on the ssl context to set the correct PSK and - * identity and return 0. + * \c mbedtls_ssl_set_hs_psk() on the ssl context to set the + * correct PSK and return 0. * Any other return value will result in a denied PSK identity. * + * \note If you set a PSK callback using this function, then you + * don't need to set a PSK key and identity using + * \c mbedtls_ssl_set_psk(). + * * \param conf SSL configuration * \param f_psk PSK identity function * \param p_psk PSK identity parameter diff --git a/library/ssl_tls.c b/library/ssl_tls.c index c599761b0..d3ec5dcf5 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1066,6 +1066,15 @@ int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exch { unsigned char *p = ssl->handshake->premaster; unsigned char *end = p + sizeof( ssl->handshake->premaster ); + const unsigned char *psk = ssl->conf->psk; + size_t psk_len = ssl->conf->psk_len; + + /* If the psk callback was called, use its result */ + if( ssl->handshake->psk != NULL ) + { + psk = ssl->handshake->psk; + psk_len = ssl->handshake->psk_len; + } /* * PMS = struct { @@ -1077,12 +1086,12 @@ int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exch #if defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED) if( key_ex == MBEDTLS_KEY_EXCHANGE_PSK ) { - if( end - p < 2 + (int) ssl->conf->psk_len ) + if( end - p < 2 + (int) psk_len ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - *(p++) = (unsigned char)( ssl->conf->psk_len >> 8 ); - *(p++) = (unsigned char)( ssl->conf->psk_len ); - p += ssl->conf->psk_len; + *(p++) = (unsigned char)( psk_len >> 8 ); + *(p++) = (unsigned char)( psk_len ); + p += psk_len; } else #endif /* MBEDTLS_KEY_EXCHANGE_PSK_ENABLED */ @@ -1149,13 +1158,13 @@ int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exch } /* opaque psk<0..2^16-1>; */ - if( end - p < 2 + (int) ssl->conf->psk_len ) + if( end - p < 2 + (int) psk_len ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - *(p++) = (unsigned char)( ssl->conf->psk_len >> 8 ); - *(p++) = (unsigned char)( ssl->conf->psk_len ); - memcpy( p, ssl->conf->psk, ssl->conf->psk_len ); - p += ssl->conf->psk_len; + *(p++) = (unsigned char)( psk_len >> 8 ); + *(p++) = (unsigned char)( psk_len ); + memcpy( p, psk, psk_len ); + p += psk_len; ssl->handshake->pmslen = p - ssl->handshake->premaster; @@ -5353,8 +5362,9 @@ int mbedtls_ssl_set_own_cert( mbedtls_ssl_context *ssl, mbedtls_x509_crt *own_ce #endif /* MBEDTLS_X509_CRT_PARSE_C */ #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) -int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len, - const unsigned char *psk_identity, size_t psk_identity_len ) +int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, + const unsigned char *psk, size_t psk_len, + const unsigned char *psk_identity, size_t psk_identity_len ) { if( psk == NULL || psk_identity == NULL ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); @@ -5385,6 +5395,31 @@ int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, siz return( 0 ); } +int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, + const unsigned char *psk, size_t psk_len ) +{ + if( psk == NULL || ssl->handshake == NULL ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + if( psk_len > MBEDTLS_PSK_MAX_LEN ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + if( ssl->handshake->psk != NULL ) + mbedtls_free( ssl->conf->psk ); + + if( ( ssl->handshake->psk = mbedtls_malloc( psk_len ) ) == NULL ) + { + mbedtls_free( ssl->handshake->psk ); + ssl->handshake->psk = NULL; + return( MBEDTLS_ERR_SSL_MALLOC_FAILED ); + } + + ssl->handshake->psk_len = psk_len; + memcpy( ssl->handshake->psk, psk, ssl->handshake->psk_len ); + + return( 0 ); +} + void mbedtls_ssl_set_psk_cb( mbedtls_ssl_config *conf, int (*f_psk)(void *, mbedtls_ssl_context *, const unsigned char *, size_t), @@ -6441,6 +6476,14 @@ void mbedtls_ssl_handshake_free( mbedtls_ssl_handshake_params *handshake ) mbedtls_free( (void *) handshake->curves ); #endif +#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) + if( handshake->psk != NULL ) + { + mbedtls_zeroize( handshake->psk, handshake->psk_len ); + mbedtls_free( handshake->psk ); + } +#endif + #if defined(MBEDTLS_X509_CRT_PARSE_C) && \ defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) /* diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c index e3e680ce5..12614c128 100644 --- a/programs/ssl/ssl_server2.c +++ b/programs/ssl/ssl_server2.c @@ -678,8 +678,7 @@ int psk_callback( void *p_info, mbedtls_ssl_context *ssl, if( name_len == strlen( cur->name ) && memcmp( name, cur->name, name_len ) == 0 ) { - return( mbedtls_ssl_set_psk( ssl, cur->key, cur->key_len, - name, name_len ) ); + return( mbedtls_ssl_set_hs_psk( ssl, cur->key, cur->key_len ) ); } cur = cur->next;