diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 9a1d22044..442dba28f 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -579,6 +579,7 @@ struct _ssl_context * PKI layer */ pk_context *pk_key; /*!< own private key */ + int pk_key_own_alloc; /*!< did we allocate pk_key? */ #if defined(POLARSSL_RSA_C) int rsa_use_alt; /*minor_ver == SSL_MINOR_VERSION_3 ) ssl->out_msg[5] = SSL_SIG_RSA; - if( ssl->rsa_use_alt ) + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + ssl->out_msg + 6 + offset, &n, + ssl->f_rng, ssl->p_rng ) ) != 0 ) { - if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, - RSA_PRIVATE, md_alg, - hashlen, hash, ssl->out_msg + 6 + offset ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "rsa_sign", ret ); - return( ret ); - } - - n = ssl->rsa_key_len ( ssl->rsa_key ); - } - else - { - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, - ssl->out_msg + 6 + offset, &n, - ssl->f_rng, ssl->p_rng ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "pk_sign", ret ); - return( ret ); - } + SSL_DEBUG_RET( 1, "pk_sign", ret ); + return( ret ); } } else diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 6fb16ecab..0fa4f66f0 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -2080,27 +2080,12 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) n += 2; } - if( ssl->rsa_use_alt ) + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + p + 2 , &signature_len, + ssl->f_rng, ssl->p_rng ) ) != 0 ) { - if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, - ssl->p_rng, RSA_PRIVATE, md_alg, hashlen, - hash, p + 2 ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "rsa_sign", ret ); - return( ret ); - } - - signature_len = ssl->rsa_key_len( ssl->rsa_key ); - } - else - { - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, - p + 2 , &signature_len, - ssl->f_rng, ssl->p_rng ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "pk_sign", ret ); - return( ret ); - } + SSL_DEBUG_RET( 1, "pk_sign", ret ); + return( ret ); } } else @@ -2289,21 +2274,11 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE ); } - if( ssl->rsa_use_alt ) { - ret = ssl->rsa_decrypt( ssl->rsa_key, RSA_PRIVATE, - &ssl->handshake->pmslen, - ssl->in_msg + i, - ssl->handshake->premaster, - sizeof(ssl->handshake->premaster) ); - } - else - { - ret = pk_decrypt( ssl->pk_key, - ssl->in_msg + i, n, - ssl->handshake->premaster, &ssl->handshake->pmslen, - sizeof(ssl->handshake->premaster), - ssl->f_rng, ssl->p_rng ); - } + ret = pk_decrypt( ssl->pk_key, + ssl->in_msg + i, n, + ssl->handshake->premaster, &ssl->handshake->pmslen, + sizeof(ssl->handshake->premaster), + ssl->f_rng, ssl->p_rng ); if( ret != 0 || ssl->handshake->pmslen != 48 || ssl->handshake->premaster[0] != ssl->handshake->max_major_ver || diff --git a/library/ssl_tls.c b/library/ssl_tls.c index d4723d759..9e446f613 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3162,18 +3162,30 @@ void ssl_set_own_cert_rsa( ssl_context *ssl, x509_cert *own_cert, } #endif /* POLARSSL_RSA_C */ -void ssl_set_own_cert_alt_rsa( ssl_context *ssl, x509_cert *own_cert, +int ssl_set_own_cert_alt_rsa( ssl_context *ssl, x509_cert *own_cert, void *rsa_key, rsa_decrypt_func rsa_decrypt, rsa_sign_func rsa_sign, rsa_key_len_func rsa_key_len ) { + int ret; + ssl->own_cert = own_cert; ssl->rsa_use_alt = 1; ssl->rsa_key = rsa_key; ssl->rsa_decrypt = rsa_decrypt; ssl->rsa_sign = rsa_sign; ssl->rsa_key_len = rsa_key_len; + + if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL ) + return( POLARSSL_ERR_SSL_MALLOC_FAILED ); + + ssl->pk_key_own_alloc = 1; + + pk_init( ssl->pk_key ); + + return( pk_init_ctx_rsa_alt( ssl->pk_key, rsa_key, + rsa_decrypt, rsa_sign, rsa_key_len ) ); } #endif /* POLARSSL_X509_PARSE_C */ @@ -3780,6 +3792,12 @@ void ssl_free( ssl_context *ssl ) ssl->hostname_len = 0; } + if( ssl->pk_key_own_alloc ) + { + pk_free( ssl->pk_key ); + polarssl_free( ssl->pk_key ); + } + #if defined(POLARSSL_SSL_HW_RECORD_ACCEL) if( ssl_hw_record_finish != NULL ) {