diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 93b3170ba..c764961d0 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -56,9 +56,8 @@ #if defined(POLARSSL_X509_CRT_PARSE_C) #include "x509_crt.h" -#endif - #include "x509_crl.h" +#endif #if defined(POLARSSL_DHM_C) #include "dhm.h" @@ -394,6 +393,9 @@ typedef struct _ssl_handshake_params ssl_handshake_params; #if defined(POLARSSL_SSL_SESSION_TICKETS) typedef struct _ssl_ticket_keys ssl_ticket_keys; #endif +#if defined(POLARSSL_X509_CRT_PARSE_C) +typedef struct _ssl_key_cert ssl_key_cert; +#endif /* * This structure is used for storing current session data. @@ -543,6 +545,19 @@ struct _ssl_ticket_keys }; #endif /* POLARSSL_SSL_SESSION_TICKETS */ +#if defined(POLARSSL_X509_CRT_PARSE_C) +/* + * List of certificate + private key pairs + */ +struct _ssl_key_cert +{ + x509_crt *cert; /*!< cert */ + pk_context *key; /*!< private key */ + int key_own_alloc; /*!< did we allocate key? */ + ssl_key_cert *next; /*!< next key/cert pair */ +}; +#endif /* POLARSSL_X509_CRT_PARSE_C */ + struct _ssl_context { /* @@ -647,22 +662,18 @@ struct _ssl_context /* * PKI layer */ -#if defined(POLARSSL_PK_C) - pk_context *pk_key; /*!< own private key */ - int pk_key_own_alloc; /*!< did we allocate pk_key? */ -#endif - #if defined(POLARSSL_X509_CRT_PARSE_C) - x509_crt *own_cert; /*!< own X.509 certificate */ - x509_crt *ca_chain; /*!< own trusted CA chain */ - const char *peer_cn; /*!< expected peer CN */ -#endif /* POLARSSL_X509_CRT_PARSE_C */ - x509_crl *ca_crl; /*!< trusted CA CRLs */ + ssl_key_cert *key_cert; /*!< own certificate(s)/key(s) */ + + x509_crt *ca_chain; /*!< own trusted CA chain */ + x509_crl *ca_crl; /*!< trusted CA CRLs */ + const char *peer_cn; /*!< expected peer CN */ +#endif /* POLARSSL_X509_CRT_PARSE_C */ -#if defined(POLARSSL_SSL_SESSION_TICKETS) /* * Support for generating and checking session tickets */ +#if defined(POLARSSL_SSL_SESSION_TICKETS) ssl_ticket_keys *ticket_keys; /*!< keys for ticket encryption */ #endif /* POLARSSL_SSL_SESSION_TICKETS */ @@ -966,15 +977,22 @@ void ssl_set_ca_chain( ssl_context *ssl, x509_crt *ca_chain, /** * \brief Set own certificate chain and private key * - * Note: own_cert should contain IN order from the bottom - * up your certificate chain. The top certificate (self-signed) + * \note own_cert should contain in order from the bottom up your + * certificate chain. The top certificate (self-signed) * can be omitted. * + * \note This function may be called more than once if you want to + * support multiple certificates (eg, one using RSA and one + * using ECDSA). However, on client, currently only the first + * certificate is used (subsequent calls have no effect). + * * \param ssl SSL context * \param own_cert own public certificate chain * \param pk_key own private key + * + * \return 0 on success or POLARSSL_ERR_SSL_MALLOC_FAILED */ -void ssl_set_own_cert( ssl_context *ssl, x509_crt *own_cert, +int ssl_set_own_cert( ssl_context *ssl, x509_crt *own_cert, pk_context *pk_key ); #if defined(POLARSSL_RSA_C) @@ -1496,6 +1514,18 @@ pk_type_t ssl_pk_alg_from_sig( unsigned char sig ); md_type_t ssl_md_alg_from_hash( unsigned char hash ); +#if defined(POLARSSL_X509_CRT_PARSE_C) +static inline pk_context *ssl_own_key( ssl_context *ssl ) +{ + return( ssl->key_cert == NULL ? NULL : ssl->key_cert->key ); +} + +static inline x509_crt *ssl_own_cert( ssl_context *ssl ) +{ + return( ssl->key_cert == NULL ? NULL : ssl->key_cert->cert ); +} +#endif /* POLARSSL_X509_CRT_PARSE_C */ + #ifdef __cplusplus } #endif diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 7d1a83285..ae8c916c8 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1595,7 +1595,7 @@ static int ssl_parse_certificate_request( ssl_context *ssl ) { #if defined(POLARSSL_RSA_C) if( *p == SSL_CERT_TYPE_RSA_SIGN && - pk_can_do( ssl->pk_key, POLARSSL_PK_RSA ) ) + pk_can_do( ssl_own_key( ssl ), POLARSSL_PK_RSA ) ) { ssl->handshake->cert_type = SSL_CERT_TYPE_RSA_SIGN; break; @@ -1604,7 +1604,7 @@ static int ssl_parse_certificate_request( ssl_context *ssl ) #endif #if defined(POLARSSL_ECDSA_C) if( *p == SSL_CERT_TYPE_ECDSA_SIGN && - pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) ) + pk_can_do( ssl_own_key( ssl ), POLARSSL_PK_ECDSA ) ) { ssl->handshake->cert_type = SSL_CERT_TYPE_ECDSA_SIGN; break; @@ -2005,14 +2005,14 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) return( 0 ); } - if( ssl->client_auth == 0 || ssl->own_cert == NULL ) + if( ssl->client_auth == 0 || ssl_own_cert( ssl ) == NULL ) { SSL_DEBUG_MSG( 2, ( "<= skip write certificate verify" ) ); ssl->state++; return( 0 ); } - if( ssl->pk_key == NULL ) + if( ssl_own_key( ssl ) == NULL ) { SSL_DEBUG_MSG( 1, ( "got no private key" ) ); return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); @@ -2045,7 +2045,7 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) /* * For ECDSA, default hash is SHA-1 only */ - if( pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) ) + if( pk_can_do( ssl_own_key( ssl ), POLARSSL_PK_ECDSA ) ) { hash_start += 16; hashlen -= 16; @@ -2084,7 +2084,7 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) md_alg = POLARSSL_MD_SHA256; ssl->out_msg[4] = SSL_HASH_SHA256; } - ssl->out_msg[5] = ssl_sig_from_pk( ssl->pk_key ); + ssl->out_msg[5] = ssl_sig_from_pk( ssl_own_key( ssl ) ); /* Info from md_alg will be used instead */ hashlen = 0; @@ -2097,7 +2097,7 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE ); } - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash_start, hashlen, + if( ( ret = pk_sign( ssl_own_key( ssl ), md_alg, hash_start, hashlen, ssl->out_msg + 6 + offset, &n, ssl->f_rng, ssl->p_rng ) ) != 0 ) { diff --git a/library/ssl_srv.c b/library/ssl_srv.c index e6ce88c0c..47e3e272c 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -1306,8 +1306,8 @@ static int ssl_parse_client_hello( ssl_context *ssl ) #if defined(POLARSSL_PK_C) pk_alg = ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info ); if( pk_alg != POLARSSL_PK_NONE && - ( ssl->pk_key == NULL || - ! pk_can_do( ssl->pk_key, pk_alg ) ) ) + ( ssl_own_key( ssl ) == NULL || + ! pk_can_do( ssl_own_key( ssl ), pk_alg ) ) ) continue; #endif @@ -2065,7 +2065,7 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) /* * Make the signature */ - if( ssl->pk_key == NULL ) + if( ssl_own_key( ssl ) == NULL ) { SSL_DEBUG_MSG( 1, ( "got no private key" ) ); return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); @@ -2075,13 +2075,13 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) { *(p++) = ssl->handshake->sig_alg; - *(p++) = ssl_sig_from_pk( ssl->pk_key ); + *(p++) = ssl_sig_from_pk( ssl_own_key( ssl ) ); n += 2; } #endif /* POLARSSL_SSL_PROTO_TLS1_2 */ - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + if( ( ret = pk_sign( ssl_own_key( ssl ), md_alg, hash, hashlen, p + 2 , &signature_len, ssl->f_rng, ssl->p_rng ) ) != 0 ) { @@ -2221,7 +2221,7 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE; size_t i, n = 0; - if( ! pk_can_do( ssl->pk_key, POLARSSL_PK_RSA ) ) + if( ! pk_can_do( ssl_own_key( ssl ), POLARSSL_PK_RSA ) ) { SSL_DEBUG_MSG( 1, ( "got no RSA private key" ) ); return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); @@ -2231,7 +2231,7 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) * Decrypt the premaster using own private RSA key */ i = 4; - n = pk_get_len( ssl->pk_key ); + n = pk_get_len( ssl_own_key( ssl ) ); ssl->handshake->pmslen = 48; #if defined(POLARSSL_SSL_PROTO_TLS1) || defined(POLARSSL_SSL_PROTO_TLS1_1) || \ @@ -2254,7 +2254,7 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE ); } - ret = pk_decrypt( ssl->pk_key, + ret = pk_decrypt( ssl_own_key( ssl ), ssl->in_msg + i, n, ssl->handshake->premaster, &ssl->handshake->pmslen, sizeof(ssl->handshake->premaster), diff --git a/library/ssl_tls.c b/library/ssl_tls.c index a113ec1f2..7f5ea76bd 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -2302,7 +2302,7 @@ int ssl_write_certificate( ssl_context *ssl ) * If using SSLv3 and got no cert, send an Alert message * (otherwise an empty Certificate message will be sent). */ - if( ssl->own_cert == NULL && + if( ssl_own_cert( ssl ) == NULL && ssl->minor_ver == SSL_MINOR_VERSION_0 ) { ssl->out_msglen = 2; @@ -2317,14 +2317,14 @@ int ssl_write_certificate( ssl_context *ssl ) } else /* SSL_IS_SERVER */ { - if( ssl->own_cert == NULL ) + if( ssl_own_cert( ssl ) == NULL ) { SSL_DEBUG_MSG( 1, ( "got no certificate to send" ) ); return( POLARSSL_ERR_SSL_CERTIFICATE_REQUIRED ); } } - SSL_DEBUG_CRT( 3, "own certificate", ssl->own_cert ); + SSL_DEBUG_CRT( 3, "own certificate", ssl_own_cert( ssl ) ); /* * 0 . 0 handshake type @@ -2336,7 +2336,7 @@ int ssl_write_certificate( ssl_context *ssl ) * n+3 . ... upper level cert, etc. */ i = 7; - crt = ssl->own_cert; + crt = ssl_own_cert( ssl ); while( crt != NULL ) { @@ -3462,6 +3462,30 @@ void ssl_set_ciphersuites_for_version( ssl_context *ssl, const int *ciphersuites } #if defined(POLARSSL_X509_CRT_PARSE_C) +/* Add a new (empty) key_cert entry an return a pointer to it */ +static ssl_key_cert *ssl_add_key_cert( ssl_context *ssl ) +{ + ssl_key_cert *key_cert, *last; + + if( ( key_cert = polarssl_malloc( sizeof( ssl_key_cert ) ) ) == NULL ) + return( NULL ); + + memset( key_cert, 0, sizeof( ssl_key_cert ) ); + + /* Append the new key_cert to the (possibly empty) current list */ + if( ssl->key_cert == NULL ) + ssl->key_cert = key_cert; + else + { + last = ssl->key_cert; + while( last->next != NULL ) + last = last->next; + last->next = key_cert; + } + + return key_cert; +} + void ssl_set_ca_chain( ssl_context *ssl, x509_crt *ca_chain, x509_crl *ca_crl, const char *peer_cn ) { @@ -3470,11 +3494,18 @@ void ssl_set_ca_chain( ssl_context *ssl, x509_crt *ca_chain, ssl->peer_cn = peer_cn; } -void ssl_set_own_cert( ssl_context *ssl, x509_crt *own_cert, +int ssl_set_own_cert( ssl_context *ssl, x509_crt *own_cert, pk_context *pk_key ) { - ssl->own_cert = own_cert; - ssl->pk_key = pk_key; + ssl_key_cert *key_cert = ssl_add_key_cert( ssl ); + + if( key_cert == NULL ) + return( POLARSSL_ERR_SSL_MALLOC_FAILED ); + + key_cert->cert = own_cert; + key_cert->key = pk_key; + + return( 0 ); } #if defined(POLARSSL_RSA_C) @@ -3482,23 +3513,26 @@ int ssl_set_own_cert_rsa( ssl_context *ssl, x509_crt *own_cert, rsa_context *rsa_key ) { int ret; + ssl_key_cert *key_cert = ssl_add_key_cert( ssl ); - ssl->own_cert = own_cert; - - if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL ) + if( key_cert == NULL ) return( POLARSSL_ERR_SSL_MALLOC_FAILED ); - ssl->pk_key_own_alloc = 1; + if( ( key_cert->key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL ) + return( POLARSSL_ERR_SSL_MALLOC_FAILED ); - pk_init( ssl->pk_key ); + pk_init( key_cert->key ); - ret = pk_init_ctx( ssl->pk_key, pk_info_from_type( POLARSSL_PK_RSA ) ); + ret = pk_init_ctx( key_cert->key, pk_info_from_type( POLARSSL_PK_RSA ) ); if( ret != 0 ) return( ret ); - if( ( ret = rsa_copy( ssl->pk_key->pk_ctx, rsa_key ) ) != 0 ) + if( ( ret = rsa_copy( key_cert->key->pk_ctx, rsa_key ) ) != 0 ) return( ret ); + key_cert->cert = own_cert; + key_cert->key_own_alloc = 1; + return( 0 ); } #endif /* POLARSSL_RSA_C */ @@ -3509,17 +3543,25 @@ int ssl_set_own_cert_alt( ssl_context *ssl, x509_crt *own_cert, rsa_sign_func rsa_sign, rsa_key_len_func rsa_key_len ) { - ssl->own_cert = own_cert; + int ret; + ssl_key_cert *key_cert = ssl_add_key_cert( ssl ); - if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL ) + if( key_cert == NULL ) return( POLARSSL_ERR_SSL_MALLOC_FAILED ); - ssl->pk_key_own_alloc = 1; + if( ( key_cert->key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL ) + return( POLARSSL_ERR_SSL_MALLOC_FAILED ); - pk_init( ssl->pk_key ); + pk_init( key_cert->key ); - return( pk_init_ctx_rsa_alt( ssl->pk_key, rsa_key, - rsa_decrypt, rsa_sign, rsa_key_len ) ); + if( ( ret = pk_init_ctx_rsa_alt( key_cert->key, rsa_key, + rsa_decrypt, rsa_sign, rsa_key_len ) ) != 0 ) + return( ret ); + + key_cert->cert = own_cert; + key_cert->key_own_alloc = 1; + + return( 0 ); } #endif /* POLARSSL_X509_CRT_PARSE_C */ @@ -4188,13 +4230,26 @@ void ssl_free( ssl_context *ssl ) } #endif -#if defined(POLARSSL_PK_C) - if( ssl->pk_key_own_alloc ) +#if defined(POLARSSL_X509_CRT_PARSE_C) + if( ssl->key_cert != NULL ) { - pk_free( ssl->pk_key ); - polarssl_free( ssl->pk_key ); + ssl_key_cert *cur = ssl->key_cert, *next; + + while( cur != NULL ) + { + next = cur->next; + + if( cur->key_own_alloc ) + { + pk_free( cur->key ); + polarssl_free( cur->key ); + } + polarssl_free( cur ); + + cur = next; + } } -#endif +#endif /* POLARSSL_X509_CRT_PARSE_C */ #if defined(POLARSSL_SSL_HW_RECORD_ACCEL) if( ssl_hw_record_finish != NULL )