diff --git a/include/mbedtls/pk.h b/include/mbedtls/pk.h index 126a38967..30f0492e9 100644 --- a/include/mbedtls/pk.h +++ b/include/mbedtls/pk.h @@ -216,35 +216,6 @@ typedef struct typedef void mbedtls_pk_restart_ctx; #endif /* MBEDTLS_ECDSA_C && MBEDTLS_ECP_RESTARTABLE */ -#if defined(MBEDTLS_RSA_C) -/** - * Quick access to an RSA context inside a PK context. - * - * \warning This function can only be used when the type of the context, as - * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_RSA. - * Ensuring that is the caller's responsibility. - */ -static inline mbedtls_rsa_context *mbedtls_pk_rsa( const mbedtls_pk_context pk ) -{ - return( (mbedtls_rsa_context *) (pk).MBEDTLS_PRIVATE(pk_ctx) ); -} -#endif /* MBEDTLS_RSA_C */ - -#if defined(MBEDTLS_ECP_C) -/** - * Quick access to an EC context inside a PK context. - * - * \warning This function can only be used when the type of the context, as - * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_ECKEY, - * #MBEDTLS_PK_ECKEY_DH, or #MBEDTLS_PK_ECDSA. - * Ensuring that is the caller's responsibility. - */ -static inline mbedtls_ecp_keypair *mbedtls_pk_ec( const mbedtls_pk_context pk ) -{ - return( (mbedtls_ecp_keypair *) (pk).MBEDTLS_PRIVATE(pk_ctx) ); -} -#endif /* MBEDTLS_ECP_C */ - #if defined(MBEDTLS_PK_RSA_ALT_SUPPORT) /** * \brief Types for RSA-alt abstraction @@ -738,6 +709,47 @@ const char * mbedtls_pk_get_name( const mbedtls_pk_context *ctx ); */ mbedtls_pk_type_t mbedtls_pk_get_type( const mbedtls_pk_context *ctx ); +#if defined(MBEDTLS_RSA_C) +/** + * Quick access to an RSA context inside a PK context. + * + * \warning This function can only be used when the type of the context, as + * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_RSA. + * Ensuring that is the caller's responsibility. + * Alternatively, you can check whether this function returns NULL. + * + * \return The internal RSA context held by the PK context, or NULL. + */ +static inline mbedtls_rsa_context *mbedtls_pk_rsa( const mbedtls_pk_context pk ) +{ + return( mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_RSA ? + (mbedtls_rsa_context *) (pk).MBEDTLS_PRIVATE(pk_ctx) : + NULL ); +} +#endif /* MBEDTLS_RSA_C */ + +#if defined(MBEDTLS_ECP_C) +/** + * Quick access to an EC context inside a PK context. + * + * \warning This function can only be used when the type of the context, as + * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_ECKEY, + * #MBEDTLS_PK_ECKEY_DH, or #MBEDTLS_PK_ECDSA. + * Ensuring that is the caller's responsibility. + * Alternatively, you can check whether this function returns NULL. + * + * \return The internal EC context held by the PK context, or NULL. + */ +static inline mbedtls_ecp_keypair *mbedtls_pk_ec( const mbedtls_pk_context pk ) +{ + return( mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_ECKEY || + mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_ECKEY_DH || + mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_ECDSA ? + (mbedtls_ecp_keypair *) (pk).MBEDTLS_PRIVATE(pk_ctx) : + NULL ); +} +#endif /* MBEDTLS_ECP_C */ + #if defined(MBEDTLS_PK_PARSE_C) /** \ingroup pk_module */ /**