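Simplify PSA key type handling for ECDH/FFDH named groups

mbedtls_ssl_get_psa_curve_info_from_tls_id() now returns the PSA key type
(PSA_KEY_TYPE_ECC_KEY_PAIR(family)) instead of the bare ECC family, so
callers in the TLS 1.2 client/server and TLS 1.3 code no longer wrap the
family themselves.

In TLS 1.3, the FFDH lookup helper is renamed to
mbedtls_ssl_get_psa_ffdh_info_from_tls_id() and its return type is fixed
to psa_status_t (it was compared against PSA_SUCCESS). The key-exchange
code now selects the key algorithm (PSA_ALG_ECDH / PSA_ALG_FFDH) at the
same time as the key type. It also performs a single output-buffer size
check for every group type.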
Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 2d72cde..8ef3723 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1567,7 +1567,7 @@
* \brief Return PSA EC info for the specified TLS ID.
*
* \param tls_id The TLS ID to look for
- * \param family If the TLD ID is supported, then proper \c psa_ecc_family_t
+ * \param type If the TLS ID is supported, then proper \c psa_key_type_t
* value is returned here. Can be NULL.
* \param bits If the TLD ID is supported, then proper bit size is returned
* here. Can be NULL.
@@ -1580,7 +1580,7 @@
* simply to check if a specific TLS ID is supported.
*/
int mbedtls_ssl_get_psa_curve_info_from_tls_id(uint16_t tls_id,
- psa_ecc_family_t *family,
+ psa_key_type_t *type,
size_t *bits);
/**
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index c46f041..bee86ca 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -5614,13 +5614,13 @@
};
int mbedtls_ssl_get_psa_curve_info_from_tls_id(uint16_t tls_id,
- psa_ecc_family_t *family,
+ psa_key_type_t *type,
size_t *bits)
{
for (int i = 0; tls_id_match_table[i].tls_id != 0; i++) {
if (tls_id_match_table[i].tls_id == tls_id) {
- if (family != NULL) {
- *family = tls_id_match_table[i].psa_family;
+ if (type != NULL) {
+ *type = PSA_KEY_TYPE_ECC_KEY_PAIR(tls_id_match_table[i].psa_family);
}
if (bits != NULL) {
*bits = tls_id_match_table[i].bits;
diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c
index 775ab9b..df8af0d 100644
--- a/library/ssl_tls12_client.c
+++ b/library/ssl_tls12_client.c
@@ -1714,7 +1714,7 @@
uint16_t tls_id;
uint8_t ecpoint_len;
mbedtls_ssl_handshake_params *handshake = ssl->handshake;
- psa_ecc_family_t ec_psa_family = 0;
+ psa_key_type_t key_type = 0;
size_t ec_bits = 0;
/*
@@ -1751,11 +1751,11 @@
}
/* Convert EC's TLS ID to PSA key type. */
- if (mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &ec_psa_family,
+ if (mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &key_type,
&ec_bits) == PSA_ERROR_NOT_SUPPORTED) {
return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
}
- handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
+ handshake->ecdh_psa_type = key_type;
handshake->ecdh_bits = ec_bits;
/* Keep a copy of the peer's public key */
@@ -2014,7 +2014,7 @@
#if defined(MBEDTLS_USE_PSA_CRYPTO)
uint16_t tls_id = 0;
- psa_ecc_family_t ecc_family;
+ psa_key_type_t key_type = 0;
mbedtls_ecp_group_id grp_id = mbedtls_pk_get_group_id(peer_pk);
if (mbedtls_ssl_check_curve(ssl, grp_id) != 0) {
@@ -2031,10 +2031,10 @@
/* If the above conversion to TLS ID was fine, then also this one will be,
so there is no need to check the return value here */
- mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &ecc_family,
+ mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &key_type,
&ssl->handshake->ecdh_bits);
- ssl->handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ecc_family);
+ ssl->handshake->ecdh_psa_type = key_type;
/* Store peer's public key in psa format. */
#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 3f2aa44..3234b2d 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -2594,7 +2594,7 @@
PSA_KEY_EXPORT_ECC_KEY_PAIR_MAX_SIZE(PSA_VENDOR_ECC_MAX_CURVE_BITS)];
psa_key_attributes_t key_attributes = PSA_KEY_ATTRIBUTES_INIT;
uint16_t tls_id = 0;
- psa_ecc_family_t ecc_family;
+ psa_key_type_t key_type = 0;
size_t key_len;
mbedtls_pk_context *pk;
mbedtls_ecp_group_id grp_id;
@@ -2649,10 +2649,10 @@
/* If the above conversion to TLS ID was fine, then also this one will
be, so there is no need to check the return value here */
- mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &ecc_family,
+ mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &key_type,
&ssl->handshake->ecdh_bits);
- ssl->handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ecc_family);
+ ssl->handshake->ecdh_psa_type = key_type;
key_attributes = psa_key_attributes_init();
psa_set_key_usage_flags(&key_attributes, PSA_KEY_USAGE_DERIVE);
@@ -2961,19 +2961,19 @@
const size_t header_size = 4; // curve_type(1), namedcurve(2),
// data length(1)
const size_t data_length_size = 1;
- psa_ecc_family_t ec_psa_family = 0;
+ psa_key_type_t key_type = 0;
size_t ec_bits = 0;
MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based ECDH computation."));
/* Convert EC's TLS ID to PSA key type. */
if (mbedtls_ssl_get_psa_curve_info_from_tls_id(*curr_tls_id,
- &ec_psa_family,
+ &key_type,
&ec_bits) == PSA_ERROR_NOT_SUPPORTED) {
MBEDTLS_SSL_DEBUG_MSG(1, ("Invalid ecc group parse."));
return MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER;
}
- handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
+ handshake->ecdh_psa_type = key_type;
handshake->ecdh_bits = ec_bits;
key_attributes = psa_key_attributes_init();
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 42cabf5..030135b 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1513,7 +1513,7 @@
return 0;
}
-static psa_key_type_t mbedtls_psa_parse_tls_ffdh_group(
+static psa_status_t mbedtls_ssl_get_psa_ffdh_info_from_tls_id(
uint16_t tls_ecc_grp_reg_id, size_t *bits, psa_key_type_t *key_type)
{
switch (tls_ecc_grp_reg_id) {
@@ -1556,28 +1556,21 @@
mbedtls_ssl_handshake_params *handshake = ssl->handshake;
size_t bits = 0;
psa_key_type_t key_type = 0;
+ psa_algorithm_t alg = 0;
size_t buf_size = (size_t) (end - buf);
-
MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based ECDH/FFDH computation."));
/* Convert EC's TLS ID to PSA key type. */
#if defined(PSA_WANT_ALG_ECDH)
- psa_ecc_family_t ec_psa_family = 0;
if (mbedtls_ssl_get_psa_curve_info_from_tls_id(
- named_group, &ec_psa_family, &bits) == PSA_SUCCESS) {
- key_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
+ named_group, &key_type, &bits) == PSA_SUCCESS) {
+ alg = PSA_ALG_ECDH;
}
#endif
#if defined(PSA_WANT_ALG_FFDH)
- if (mbedtls_psa_parse_tls_ffdh_group(named_group, &bits, &key_type) == PSA_SUCCESS) {
- if (PSA_KEY_TYPE_IS_DH(key_type)) {
- if (buf_size < PSA_BITS_TO_BYTES(bits)) {
-
- return MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
- }
- buf_size = PSA_BITS_TO_BYTES(bits);
- }
+ if (mbedtls_ssl_get_psa_ffdh_info_from_tls_id(named_group, &bits, &key_type) == PSA_SUCCESS) {
+ alg = PSA_ALG_FFDH;
}
#endif
@@ -1585,22 +1578,17 @@
return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
}
+ if (buf_size < PSA_BITS_TO_BYTES(bits)) {
+
+ return MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
+ }
+
handshake->ecdh_psa_type = key_type;
ssl->handshake->ecdh_bits = bits;
key_attributes = psa_key_attributes_init();
psa_set_key_usage_flags(&key_attributes, PSA_KEY_USAGE_DERIVE);
-
- if (PSA_KEY_TYPE_IS_ECC(key_type)) {
-#if defined(PSA_WANT_ALG_ECDH)
- psa_set_key_algorithm(&key_attributes, PSA_ALG_ECDH);
-#endif
- } else {
-#if defined(PSA_WANT_ALG_FFDH)
- psa_set_key_algorithm(&key_attributes, PSA_ALG_FFDH);
-#endif
- }
-
+ psa_set_key_algorithm(&key_attributes, alg);
psa_set_key_type(&key_attributes, handshake->ecdh_psa_type);
psa_set_key_bits(&key_attributes, handshake->ecdh_bits);
@@ -1623,7 +1611,6 @@
ret = PSA_TO_MBEDTLS_ERR(status);
MBEDTLS_SSL_DEBUG_RET(1, "psa_export_public_key", ret);
return ret;
-
}
*out_len = own_pubkey_len;