diff --git a/library/ssl_client.c b/library/ssl_client.c index 4e42d00f3..08b3de802 100644 --- a/library/ssl_client.c +++ b/library/ssl_client.c @@ -966,24 +966,6 @@ int mbedtls_ssl_write_client_hello(mbedtls_ssl_context *ssl) #if defined(MBEDTLS_SSL_EARLY_DATA) if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED) { - psa_algorithm_t hash_alg = PSA_ALG_NONE; - const unsigned char *psk; - size_t psk_len; - MBEDTLS_SSL_DEBUG_MSG(1, ("in generate early keys")); - - if ((ret = mbedtls_ssl_tls13_ticket_get_psk( - ssl, &hash_alg, &psk, &psk_len)) - != 0) { - MBEDTLS_SSL_DEBUG_RET( - 1, "mbedtls_ssl_tls13_ticket_get_psk", ret); - goto cleanup; - } - - if ((ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len)) != 0) { - MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret); - goto cleanup; - } - /* Start the TLS 1.3 key schedule: * Set the PSK and derive early secret. */ diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 376f6cf9a..86f5c0b55 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1676,8 +1676,6 @@ int mbedtls_ssl_set_session(mbedtls_ssl_context *ssl, const mbedtls_ssl_session session->ciphersuite)); return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; } - ssl->handshake->ciphersuite_info = ciphersuite_info; - ssl->handshake->key_exchange_mode = MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL; } #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */ diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c index 6f91fb27b..874f2439f 100644 --- a/library/ssl_tls13_client.c +++ b/library/ssl_tls13_client.c @@ -893,11 +893,16 @@ int mbedtls_ssl_tls13_write_identities_of_pre_shared_key_ext( int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; int configured_psk_count = 0; unsigned char *p = buf; - psa_algorithm_t hash_alg; + psa_algorithm_t hash_alg = PSA_ALG_NONE; const unsigned char *identity; size_t identity_len; size_t l_binders_len = 0; size_t output_len; +#if defined(MBEDTLS_SSL_EARLY_DATA) + const unsigned char *psk; + size_t psk_len; + const mbedtls_ssl_ciphersuite_t *ciphersuite_info; +#endif *out_len = 0; *binders_len = 0; @@ -962,6 +967,30 @@ int mbedtls_ssl_tls13_write_identities_of_pre_shared_key_ext( p += output_len; l_binders_len += 1 + PSA_HASH_LENGTH(hash_alg); + +#if defined(MBEDTLS_SSL_EARLY_DATA) + MBEDTLS_SSL_DEBUG_MSG( + 1, ("Set hs psk for early data when writing the first psk")); + + ret = ssl_tls13_ticket_get_psk(ssl, &hash_alg, &psk, &psk_len); + if (ret != 0) { + MBEDTLS_SSL_DEBUG_RET( + 1, "mbedtls_ssl_tls13_ticket_get_psk", ret); + return ret; + } + + ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len); + if (ret != 0) { + MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret); + return ret; + } + + ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( + ssl->session_negotiate->ciphersuite); + ssl->handshake->ciphersuite_info = ciphersuite_info; + ssl->handshake->key_exchange_mode = + MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL; +#endif /* MBEDTLS_SSL_EARLY_DATA */ } #endif /* MBEDTLS_SSL_SESSION_TICKETS */