From 22e84de97145e22d2229e1cf9769e4dce1e172c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?= Date: Fri, 10 Jun 2022 09:48:38 +0200 Subject: [PATCH] Improve contract of mbedtls_pk_ec/rsa() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trusting the caller to perform the appropriate check is both risky, and a bit user-unfriendly. Returning NULL on error seems both safer (dereferencing a NULL pointer is more likely to result in a clean crash, while mis-casting a pointer might have deeper, less predictable consequences) and friendlier (the caller can just check the return value for NULL, which is a common idiom). Only add that as an additional way of using the function, for the sake of backwards compatibility. Calls where we know the type of the context for sure (for example because we just set it up) were legal and safe, so they should remain legal without checking the result for NULL, which would be redundant. Signed-off-by: Manuel Pégourié-Gonnard --- include/mbedtls/pk.h | 70 ++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/include/mbedtls/pk.h b/include/mbedtls/pk.h index 126a38967..30f0492e9 100644 --- a/include/mbedtls/pk.h +++ b/include/mbedtls/pk.h @@ -216,35 +216,6 @@ typedef struct typedef void mbedtls_pk_restart_ctx; #endif /* MBEDTLS_ECDSA_C && MBEDTLS_ECP_RESTARTABLE */ -#if defined(MBEDTLS_RSA_C) -/** - * Quick access to an RSA context inside a PK context. - * - * \warning This function can only be used when the type of the context, as - * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_RSA. - * Ensuring that is the caller's responsibility. - */ -static inline mbedtls_rsa_context *mbedtls_pk_rsa( const mbedtls_pk_context pk ) -{ - return( (mbedtls_rsa_context *) (pk).MBEDTLS_PRIVATE(pk_ctx) ); -} -#endif /* MBEDTLS_RSA_C */ - -#if defined(MBEDTLS_ECP_C) -/** - * Quick access to an EC context inside a PK context. - * - * \warning This function can only be used when the type of the context, as - * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_ECKEY, - * #MBEDTLS_PK_ECKEY_DH, or #MBEDTLS_PK_ECDSA. - * Ensuring that is the caller's responsibility. - */ -static inline mbedtls_ecp_keypair *mbedtls_pk_ec( const mbedtls_pk_context pk ) -{ - return( (mbedtls_ecp_keypair *) (pk).MBEDTLS_PRIVATE(pk_ctx) ); -} -#endif /* MBEDTLS_ECP_C */ - #if defined(MBEDTLS_PK_RSA_ALT_SUPPORT) /** * \brief Types for RSA-alt abstraction @@ -738,6 +709,47 @@ const char * mbedtls_pk_get_name( const mbedtls_pk_context *ctx ); */ mbedtls_pk_type_t mbedtls_pk_get_type( const mbedtls_pk_context *ctx ); +#if defined(MBEDTLS_RSA_C) +/** + * Quick access to an RSA context inside a PK context. + * + * \warning This function can only be used when the type of the context, as + * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_RSA. + * Ensuring that is the caller's responsibility. + * Alternatively, you can check whether this function returns NULL. + * + * \return The internal RSA context held by the PK context, or NULL. + */ +static inline mbedtls_rsa_context *mbedtls_pk_rsa( const mbedtls_pk_context pk ) +{ + return( mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_RSA ? + (mbedtls_rsa_context *) (pk).MBEDTLS_PRIVATE(pk_ctx) : + NULL ); +} +#endif /* MBEDTLS_RSA_C */ + +#if defined(MBEDTLS_ECP_C) +/** + * Quick access to an EC context inside a PK context. + * + * \warning This function can only be used when the type of the context, as + * returned by mbedtls_pk_get_type(), is #MBEDTLS_PK_ECKEY, + * #MBEDTLS_PK_ECKEY_DH, or #MBEDTLS_PK_ECDSA. + * Ensuring that is the caller's responsibility. + * Alternatively, you can check whether this function returns NULL. + * + * \return The internal EC context held by the PK context, or NULL. + */ +static inline mbedtls_ecp_keypair *mbedtls_pk_ec( const mbedtls_pk_context pk ) +{ + return( mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_ECKEY || + mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_ECKEY_DH || + mbedtls_pk_get_type( &pk ) == MBEDTLS_PK_ECDSA ? + (mbedtls_ecp_keypair *) (pk).MBEDTLS_PRIVATE(pk_ctx) : + NULL ); +} +#endif /* MBEDTLS_ECP_C */ + #if defined(MBEDTLS_PK_PARSE_C) /** \ingroup pk_module */ /**