Further code optimizations

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/ChangeLog.d/ffdh-tls-1-3.txt b/ChangeLog.d/ffdh-tls-1-3.txt
index d358f9b..c5d07d6 100644
--- a/ChangeLog.d/ffdh-tls-1-3.txt
+++ b/ChangeLog.d/ffdh-tls-1-3.txt
@@ -1,2 +1,6 @@
 Features
    * Add support for FFDH key exchange in TLS 1.3.
+     This is automatically enabled as soon as PSA_WANT_ALG_FFDH
+     and the ephemeral or psk-ephemeral key exchange mode are enabled.
+     By default, all groups are offered; the list of groups can be
+     configured using the existing API function mbedtls_ssl_conf_groups().
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 2d72cde..8ef3723 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1567,7 +1567,7 @@
  * \brief Return PSA EC info for the specified TLS ID.
  *
  * \param tls_id    The TLS ID to look for
- * \param family    If the TLD ID is supported, then proper \c psa_ecc_family_t
+ * \param type      If the TLD ID is supported, then proper \c psa_key_type_t
  *                  value is returned here. Can be NULL.
  * \param bits      If the TLD ID is supported, then proper bit size is returned
  *                  here. Can be NULL.
@@ -1580,7 +1580,7 @@
  *                  simply to check if a specific TLS ID is supported.
  */
 int mbedtls_ssl_get_psa_curve_info_from_tls_id(uint16_t tls_id,
-                                               psa_ecc_family_t *family,
+                                               psa_key_type_t *type,
                                                size_t *bits);
 
 /**
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index c46f041..bee86ca 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -5614,13 +5614,13 @@
 };
 
 int mbedtls_ssl_get_psa_curve_info_from_tls_id(uint16_t tls_id,
-                                               psa_ecc_family_t *family,
+                                               psa_key_type_t *type,
                                                size_t *bits)
 {
     for (int i = 0; tls_id_match_table[i].tls_id != 0; i++) {
         if (tls_id_match_table[i].tls_id == tls_id) {
-            if (family != NULL) {
-                *family = tls_id_match_table[i].psa_family;
+            if (type != NULL) {
+                *type = PSA_KEY_TYPE_ECC_KEY_PAIR(tls_id_match_table[i].psa_family);
             }
             if (bits != NULL) {
                 *bits = tls_id_match_table[i].bits;
diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c
index 775ab9b..df8af0d 100644
--- a/library/ssl_tls12_client.c
+++ b/library/ssl_tls12_client.c
@@ -1714,7 +1714,7 @@
     uint16_t tls_id;
     uint8_t ecpoint_len;
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
-    psa_ecc_family_t ec_psa_family = 0;
+    psa_key_type_t key_type = 0;
     size_t ec_bits = 0;
 
     /*
@@ -1751,11 +1751,11 @@
     }
 
     /* Convert EC's TLS ID to PSA key type. */
-    if (mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &ec_psa_family,
+    if (mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &key_type,
                                                    &ec_bits) == PSA_ERROR_NOT_SUPPORTED) {
         return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
     }
-    handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
+    handshake->ecdh_psa_type = key_type;
     handshake->ecdh_bits = ec_bits;
 
     /* Keep a copy of the peer's public key */
@@ -2014,7 +2014,7 @@
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     uint16_t tls_id = 0;
-    psa_ecc_family_t ecc_family;
+    psa_key_type_t key_type = 0;
     mbedtls_ecp_group_id grp_id = mbedtls_pk_get_group_id(peer_pk);
 
     if (mbedtls_ssl_check_curve(ssl, grp_id) != 0) {
@@ -2031,10 +2031,10 @@
 
     /* If the above conversion to TLS ID was fine, then also this one will be,
        so there is no need to check the return value here */
-    mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &ecc_family,
+    mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &key_type,
                                                &ssl->handshake->ecdh_bits);
 
-    ssl->handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ecc_family);
+    ssl->handshake->ecdh_psa_type = key_type;
 
     /* Store peer's public key in psa format. */
 #if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 3f2aa44..3234b2d 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -2594,7 +2594,7 @@
         PSA_KEY_EXPORT_ECC_KEY_PAIR_MAX_SIZE(PSA_VENDOR_ECC_MAX_CURVE_BITS)];
     psa_key_attributes_t key_attributes = PSA_KEY_ATTRIBUTES_INIT;
     uint16_t tls_id = 0;
-    psa_ecc_family_t ecc_family;
+    psa_key_type_t key_type = 0;
     size_t key_len;
     mbedtls_pk_context *pk;
     mbedtls_ecp_group_id grp_id;
@@ -2649,10 +2649,10 @@
 
             /* If the above conversion to TLS ID was fine, then also this one will
                be, so there is no need to check the return value here */
-            mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &ecc_family,
+            mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id, &key_type,
                                                        &ssl->handshake->ecdh_bits);
 
-            ssl->handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ecc_family);
+            ssl->handshake->ecdh_psa_type = key_type;
 
             key_attributes = psa_key_attributes_init();
             psa_set_key_usage_flags(&key_attributes, PSA_KEY_USAGE_DERIVE);
@@ -2961,19 +2961,19 @@
         const size_t header_size = 4; // curve_type(1), namedcurve(2),
                                       // data length(1)
         const size_t data_length_size = 1;
-        psa_ecc_family_t ec_psa_family = 0;
+        psa_key_type_t key_type = 0;
         size_t ec_bits = 0;
 
         MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based ECDH computation."));
 
         /* Convert EC's TLS ID to PSA key type. */
         if (mbedtls_ssl_get_psa_curve_info_from_tls_id(*curr_tls_id,
-                                                       &ec_psa_family,
+                                                       &key_type,
                                                        &ec_bits) == PSA_ERROR_NOT_SUPPORTED) {
             MBEDTLS_SSL_DEBUG_MSG(1, ("Invalid ecc group parse."));
             return MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER;
         }
-        handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
+        handshake->ecdh_psa_type = key_type;
         handshake->ecdh_bits = ec_bits;
 
         key_attributes = psa_key_attributes_init();
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 42cabf5..030135b 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1513,7 +1513,7 @@
     return 0;
 }
 
-static psa_key_type_t mbedtls_psa_parse_tls_ffdh_group(
+static psa_status_t  mbedtls_ssl_get_psa_ffdh_info_from_tls_id(
     uint16_t tls_ecc_grp_reg_id, size_t *bits, psa_key_type_t *key_type)
 {
     switch (tls_ecc_grp_reg_id) {
@@ -1556,28 +1556,21 @@
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
     size_t bits = 0;
     psa_key_type_t key_type = 0;
+    psa_algorithm_t alg = 0;
     size_t buf_size = (size_t) (end - buf);
 
-
     MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based ECDH/FFDH computation."));
 
     /* Convert EC's TLS ID to PSA key type. */
 #if defined(PSA_WANT_ALG_ECDH)
-    psa_ecc_family_t ec_psa_family = 0;
     if (mbedtls_ssl_get_psa_curve_info_from_tls_id(
-            named_group, &ec_psa_family, &bits) == PSA_SUCCESS) {
-        key_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
+            named_group, &key_type, &bits) == PSA_SUCCESS) {
+        alg = PSA_ALG_ECDH;
     }
 #endif
 #if defined(PSA_WANT_ALG_FFDH)
-    if (mbedtls_psa_parse_tls_ffdh_group(named_group, &bits, &key_type) == PSA_SUCCESS) {
-        if (PSA_KEY_TYPE_IS_DH(key_type)) {
-            if (buf_size < PSA_BITS_TO_BYTES(bits)) {
-
-                return MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
-            }
-            buf_size = PSA_BITS_TO_BYTES(bits);
-        }
+    if (mbedtls_ssl_get_psa_ffdh_info_from_tls_id(named_group, &bits, &key_type) == PSA_SUCCESS) {
+        alg = PSA_ALG_FFDH;
     }
 #endif
 
@@ -1585,22 +1578,17 @@
         return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
     }
 
+    if (buf_size < PSA_BITS_TO_BYTES(bits)) {
+
+        return MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
+    }
+
     handshake->ecdh_psa_type = key_type;
     ssl->handshake->ecdh_bits = bits;
 
     key_attributes = psa_key_attributes_init();
     psa_set_key_usage_flags(&key_attributes, PSA_KEY_USAGE_DERIVE);
-
-    if (PSA_KEY_TYPE_IS_ECC(key_type)) {
-#if defined(PSA_WANT_ALG_ECDH)
-        psa_set_key_algorithm(&key_attributes, PSA_ALG_ECDH);
-#endif
-    } else {
-#if defined(PSA_WANT_ALG_FFDH)
-        psa_set_key_algorithm(&key_attributes, PSA_ALG_FFDH);
-#endif
-    }
-
+    psa_set_key_algorithm(&key_attributes, alg);
     psa_set_key_type(&key_attributes, handshake->ecdh_psa_type);
     psa_set_key_bits(&key_attributes, handshake->ecdh_bits);
 
@@ -1623,7 +1611,6 @@
         ret = PSA_TO_MBEDTLS_ERR(status);
         MBEDTLS_SSL_DEBUG_RET(1, "psa_export_public_key", ret);
         return ret;
-
     }
 
     *out_len = own_pubkey_len;
diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c
index dcf3087..eb47af1 100644
--- a/programs/ssl/ssl_client2.c
+++ b/programs/ssl/ssl_client2.c
@@ -466,10 +466,6 @@
     USAGE_SERIALIZATION                                                       \
     " acceptable ciphersuite names:\n"
 
-#define ALPN_LIST_SIZE    10
-#define CURVE_LIST_SIZE   25
-#define SIG_ALG_LIST_SIZE  5
-
 /*
  * global options
  */
@@ -1530,7 +1526,7 @@
                          curve_cur++) {
                         mbedtls_printf("%s ", curve_cur->name);
                     }
-                    uint16_t *supported_ffdh_group = mbedtls_ssl_ffdh_supported_groups();
+                    const uint16_t *supported_ffdh_group = mbedtls_ssl_ffdh_supported_groups();
                     while (*supported_ffdh_group != 0) {
                         mbedtls_printf("%s ",
                                        mbedtls_ssl_ffdh_name_from_group(*supported_ffdh_group));
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 9919e08..1986b35 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -587,10 +587,6 @@
     USAGE_SERIALIZATION                                                       \
     " acceptable ciphersuite names:\n"
 
-#define ALPN_LIST_SIZE    10
-#define CURVE_LIST_SIZE   25
-#define SIG_ALG_LIST_SIZE 5
-
 #define PUT_UINT64_BE(out_be, in_le, i)                                   \
     {                                                                       \
         (out_be)[(i) + 0] = (unsigned char) (((in_le) >> 56) & 0xFF);    \
@@ -2423,7 +2419,7 @@
                          curve_cur->grp_id != MBEDTLS_ECP_DP_NONE;
                          curve_cur++) {
                         mbedtls_printf("%s ", curve_cur->name);
-                        uint16_t *supported_ffdh_group = mbedtls_ssl_ffdh_supported_groups();
+                        const uint16_t *supported_ffdh_group = mbedtls_ssl_ffdh_supported_groups();
                         while (*supported_ffdh_group != 0) {
                             mbedtls_printf("%s ",
                                            mbedtls_ssl_ffdh_name_from_group(*supported_ffdh_group));
diff --git a/programs/ssl/ssl_test_lib.c b/programs/ssl/ssl_test_lib.c
index ea422e9..26824c2 100644
--- a/programs/ssl/ssl_test_lib.c
+++ b/programs/ssl/ssl_test_lib.c
@@ -465,9 +465,9 @@
     return 0;
 }
 
-uint16_t *mbedtls_ssl_ffdh_supported_groups(void)
+const uint16_t *mbedtls_ssl_ffdh_supported_groups(void)
 {
-    static uint16_t ffdh_groups[] = {
+    static const uint16_t ffdh_groups[] = {
         MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE2048,
         MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE3072,
         MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE4096,
diff --git a/programs/ssl/ssl_test_lib.h b/programs/ssl/ssl_test_lib.h
index 5f9dbdd..c2afc96 100644
--- a/programs/ssl/ssl_test_lib.h
+++ b/programs/ssl/ssl_test_lib.h
@@ -80,6 +80,10 @@
 
 #include "../test/query_config.h"
 
+#define ALPN_LIST_SIZE    10
+#define CURVE_LIST_SIZE   25
+#define SIG_ALG_LIST_SIZE  5
+
 typedef struct eap_tls_keys {
     unsigned char master_secret[48];
     unsigned char randbytes[64];
@@ -309,7 +313,7 @@
 
 /* Helper functions for FFDH groups. */
 uint16_t mbedtls_ssl_ffdh_group_from_name(const char *name);
-uint16_t *mbedtls_ssl_ffdh_supported_groups(void);
+const uint16_t *mbedtls_ssl_ffdh_supported_groups(void);
 
 #endif /* MBEDTLS_SSL_TEST_IMPOSSIBLE conditions: else */
 #endif /* MBEDTLS_PROGRAMS_SSL_SSL_TEST_LIB_H */
diff --git a/tests/include/test/ssl_helpers.h b/tests/include/test/ssl_helpers.h
index 572b6cb..e7bfec9 100644
--- a/tests/include/test/ssl_helpers.h
+++ b/tests/include/test/ssl_helpers.h
@@ -602,8 +602,8 @@
     TEST_EQUAL(mbedtls_ssl_get_tls_id_from_ecp_group_id(group_id_),      \
                tls_id_);                                                 \
     TEST_EQUAL(mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id_,       \
-                                                          &psa_family, &psa_bits), PSA_SUCCESS);                \
-    TEST_EQUAL(psa_family_, psa_family);                                 \
+                                                          &psa_type, &psa_bits), PSA_SUCCESS);                \
+    TEST_EQUAL(psa_family_, PSA_KEY_TYPE_ECC_GET_FAMILY(psa_type));    \
     TEST_EQUAL(psa_bits_, psa_bits);
 
 #define TEST_UNAVAILABLE_ECC(tls_id_, group_id_, psa_family_, psa_bits_) \
@@ -612,7 +612,7 @@
     TEST_EQUAL(mbedtls_ssl_get_tls_id_from_ecp_group_id(group_id_),      \
                0);                                                       \
     TEST_EQUAL(mbedtls_ssl_get_psa_curve_info_from_tls_id(tls_id_,       \
-                                                          &psa_family, &psa_bits), \
+                                                          &psa_type, &psa_bits), \
                PSA_ERROR_NOT_SUPPORTED);
 
 #endif /* MBEDTLS_SSL_TLS_C */
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 6f9e544..fd10595 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -3591,7 +3591,7 @@
 /* BEGIN_CASE */
 void elliptic_curve_get_properties()
 {
-    psa_ecc_family_t psa_family;
+    psa_key_type_t psa_type = 0;
     size_t psa_bits;
 
     MD_OR_USE_PSA_INIT();