Add support for FFDH in TLS 1.3

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index e347853..8f14349 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -33,6 +33,7 @@
 #include "ssl_client.h"
 #include "ssl_tls13_keys.h"
 #include "ssl_debug_helpers.h"
+#include "mbedtls/dhm.h"
 
 #define PSA_TO_MBEDTLS_ERR(status) PSA_TO_MBEDTLS_ERR_LIST(status,   \
                                                            psa_to_ssl_errors,             \
@@ -185,8 +186,9 @@
         return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
     }
 
-#if defined(PSA_WANT_ALG_ECDH)
-    if (mbedtls_ssl_tls13_named_group_is_ecdhe(group_id)) {
+#if defined(PSA_WANT_ALG_ECDH) || defined(PSA_WANT_ALG_FFDH)
+    if (mbedtls_ssl_tls13_named_group_is_ecdhe(group_id) ||
+        mbedtls_ssl_tls13_named_group_is_dhe(group_id)) {
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
@@ -201,7 +203,7 @@
         ssl->handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
         return 0;
     } else
-#endif /* PSA_WANT_ALG_ECDH */
+#endif /* PSA_WANT_ALG_ECDH || PSA_WANT_ALG_FFDH */
     if (0 /* other KEMs? */) {
         /* Do something */
     }
@@ -220,13 +222,34 @@
     int ret = MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
 
 
-#if defined(PSA_WANT_ALG_ECDH)
+#if defined(PSA_WANT_ALG_ECDH) || defined(PSA_WANT_ALG_FFDH)
     const uint16_t *group_list = mbedtls_ssl_get_groups(ssl);
-    /* Pick first available ECDHE group compatible with TLS 1.3 */
+    /* Pick the first configured group usable for TLS 1.3 key exchange */
     if (group_list == NULL) {
         return MBEDTLS_ERR_SSL_BAD_CONFIG;
     }
 
+#if defined(PSA_WANT_ALG_FFDH)
+    /*
+     * Groups are considered in configured preference order. Return an
+     * FFDHE group only if no usable ECDHE group precedes it; otherwise stop
+     * and let the ECDHE loop below pick that earlier ECDHE group.
+     */
+    for (const uint16_t *candidate = group_list; *candidate != 0; candidate++) {
+        if (mbedtls_ssl_tls13_named_group_is_dhe(*candidate)) {
+            *group_id = *candidate;
+            return 0;
+        }
+#if defined(PSA_WANT_ALG_ECDH)
+        if (mbedtls_ssl_tls13_named_group_is_ecdhe(*candidate) &&
+            mbedtls_ssl_get_psa_curve_info_from_tls_id(
+                *candidate, NULL, NULL) == PSA_SUCCESS) {
+            break;
+        }
+#endif /* PSA_WANT_ALG_ECDH */
+    }
+#endif /* PSA_WANT_ALG_FFDH */
+#if defined(PSA_WANT_ALG_ECDH)
     for (; *group_list != 0; group_list++) {
         if ((mbedtls_ssl_get_psa_curve_info_from_tls_id(
                  *group_list, NULL, NULL) == PSA_SUCCESS) &&
@@ -235,10 +258,11 @@
             return 0;
         }
     }
+#endif /* PSA_WANT_ALG_ECDH */
 #else
     ((void) ssl);
     ((void) group_id);
-#endif /* PSA_WANT_ALG_ECDH */
+#endif /* PSA_WANT_ALG_ECDH || PSA_WANT_ALG_FFDH */
 
     /*
      * Add DHE named groups here.
@@ -302,8 +312,9 @@
      * only one key share entry is allowed.
      */
     client_shares = p;
-#if defined(PSA_WANT_ALG_ECDH)
-    if (mbedtls_ssl_tls13_named_group_is_ecdhe(group_id)) {
+#if defined(PSA_WANT_ALG_ECDH) || defined(PSA_WANT_ALG_FFDH)
+    if (mbedtls_ssl_tls13_named_group_is_ecdhe(group_id) ||
+        mbedtls_ssl_tls13_named_group_is_dhe(group_id)) {
         /* Pointer to group */
         unsigned char *group = p;
         /* Length of key_exchange */
@@ -315,8 +326,23 @@
          */
         MBEDTLS_SSL_CHK_BUF_PTR(p, end, 4);
         p += 4;
-        ret = mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
-            ssl, group_id, p, end, &key_exchange_len);
+#if defined(PSA_WANT_ALG_ECDH)
+        if (mbedtls_ssl_tls13_named_group_is_ecdhe(group_id)) {
+            ret = mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
+                ssl, group_id, p, end, &key_exchange_len);
+        } else
+#endif /* PSA_WANT_ALG_ECDH */
+#if defined(PSA_WANT_ALG_FFDH)
+        if (mbedtls_ssl_tls13_named_group_is_dhe(group_id)) {
+            ret = mbedtls_ssl_tls13_generate_and_write_dhe_key_exchange(
+                ssl, group_id, p, end, &key_exchange_len);
+        } else
+#endif /* PSA_WANT_ALG_FFDH */
+        {
+            /* The group's family is compiled out: no key share can be produced. */
+            MBEDTLS_SSL_DEBUG_MSG(1, ("Group not supported."));
+            return MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
+        }
         p += key_exchange_len;
         if (ret != 0) {
             return ret;
@@ -327,7 +353,7 @@
         /* Write key_exchange_length */
         MBEDTLS_PUT_UINT16_BE(key_exchange_len, group, 2);
     } else
-#endif /* PSA_WANT_ALG_ECDH */
+#endif /* PSA_WANT_ALG_ECDH || PSA_WANT_ALG_FFDH */
     if (0 /* other KEMs? */) {
         /* Do something */
     } else {
@@ -404,15 +425,25 @@
      * then the client MUST abort the handshake with an "illegal_parameter" alert.
      */
     for (; *group_list != 0; group_list++) {
-        if ((mbedtls_ssl_get_psa_curve_info_from_tls_id(
-                 *group_list, NULL, NULL) == PSA_ERROR_NOT_SUPPORTED) ||
-            *group_list != selected_group) {
-            continue;
-        }
-
-        /* We found a match */
-        found = 1;
-        break;
+        /* Only the group selected by the server can be a match. */
+        if (*group_list != selected_group) {
+            continue;
+        }
+
+        /* It is a match only if this build can perform the key exchange. */
+#if defined(PSA_WANT_ALG_ECDH)
+        if (mbedtls_ssl_tls13_named_group_is_ecdhe(*group_list) &&
+            mbedtls_ssl_get_psa_curve_info_from_tls_id(
+                *group_list, NULL, NULL) == PSA_SUCCESS) {
+            found = 1;
+        }
+#endif /* PSA_WANT_ALG_ECDH */
+#if defined(PSA_WANT_ALG_FFDH)
+        if (mbedtls_ssl_tls13_named_group_is_dhe(*group_list)) {
+            found = 1;
+        }
+#endif /* PSA_WANT_ALG_FFDH */
+        break;
     }
 
     /* Client MUST verify that the selected_group field does not
@@ -482,24 +506,34 @@
         return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
     }
 
+#if defined(PSA_WANT_ALG_ECDH) || defined(PSA_WANT_ALG_FFDH)
+    if (mbedtls_ssl_tls13_named_group_is_ecdhe(group) ||
+        mbedtls_ssl_tls13_named_group_is_dhe(group)) {
 #if defined(PSA_WANT_ALG_ECDH)
-    if (mbedtls_ssl_tls13_named_group_is_ecdhe(group)) {
-        if (mbedtls_ssl_get_psa_curve_info_from_tls_id(group, NULL, NULL)
-            == PSA_ERROR_NOT_SUPPORTED) {
-            MBEDTLS_SSL_DEBUG_MSG(1, ("Invalid TLS curve group id"));
-            return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+        if (mbedtls_ssl_tls13_named_group_is_ecdhe(group)) {
+            if (mbedtls_ssl_get_psa_curve_info_from_tls_id(group, NULL, NULL)
+                == PSA_ERROR_NOT_SUPPORTED) {
+                MBEDTLS_SSL_DEBUG_MSG(1, ("Invalid TLS curve group id"));
+                return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+            }
+
+            MBEDTLS_SSL_DEBUG_MSG(
+                2,
+                ("ECDH curve: %s", mbedtls_ssl_get_curve_name_from_tls_id(group)));
         }
-
-        MBEDTLS_SSL_DEBUG_MSG(
-            2,
-            ("ECDH curve: %s", mbedtls_ssl_get_curve_name_from_tls_id(group)));
-
+#endif /* PSA_WANT_ALG_ECDH */
+#if defined(PSA_WANT_ALG_FFDH)
+        if (mbedtls_ssl_tls13_named_group_is_dhe(group)) {
+            MBEDTLS_SSL_DEBUG_MSG(2,
+                                  ("DHE group name: %s", mbedtls_ssl_ffdh_name_from_group(group)));
+        }
+#endif /* PSA_WANT_ALG_FFDH */
         ret = mbedtls_ssl_tls13_read_public_ecdhe_share(ssl, p, end - p);
         if (ret != 0) {
             return ret;
         }
     } else
-#endif /* PSA_WANT_ALG_ECDH */
+#endif /* PSA_WANT_ALG_ECDH || PSA_WANT_ALG_FFDH */
     if (0 /* other KEMs? */) {
         /* Do something */
     } else {
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index a00785b..a7fddf9 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1571,6 +1571,79 @@
 }
 #endif /* PSA_WANT_ALG_ECDH */
 
+#if defined(PSA_WANT_ALG_FFDH)
+/*
+ * Generate a key pair for the FFDHE named group `named_group` through PSA
+ * and write its public value to buf.
+ *
+ * On success the private key handle stays in ssl->handshake->ecdh_psa_privkey,
+ * the PSA key type and size are recorded in ssl->handshake->ecdh_psa_type
+ * and ssl->handshake->ecdh_bits, *out_len receives the number of bytes
+ * written, and 0 is returned.
+ *
+ * Errors:
+ *  - MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE if the group maps to no PSA key type;
+ *  - MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL if [buf, end) is shorter than the
+ *    group's size in bytes;
+ *  - a PSA status converted by psa_ssl_status_to_mbedtls() otherwise.
+ */
+int mbedtls_ssl_tls13_generate_and_write_dhe_key_exchange(
+    mbedtls_ssl_context *ssl,
+    uint16_t named_group,
+    unsigned char *buf,
+    unsigned char *end,
+    size_t *out_len)
+{
+    mbedtls_ssl_handshake_params *hs = ssl->handshake;
+    psa_key_attributes_t attributes = psa_key_attributes_init();
+    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
+    size_t group_bits = 0;
+    size_t pubkey_size = 0;
+    size_t exported_len = 0;
+    int ret = MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
+
+    MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based DHE computation."));
+
+    /* Map the TLS group identifier to a PSA key type and size in bits. */
+    hs->ecdh_psa_type = mbedtls_psa_parse_tls_ffdh_group(named_group, &group_bits);
+    if (hs->ecdh_psa_type == 0) {
+        return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
+    }
+
+    /* The output must have room for a value as wide as the group. */
+    pubkey_size = PSA_BITS_TO_BYTES(group_bits);
+    if ((size_t) (end - buf) < pubkey_size) {
+        return MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
+    }
+
+    hs->ecdh_bits = group_bits;
+
+    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
+    psa_set_key_algorithm(&attributes, PSA_ALG_FFDH);
+    psa_set_key_type(&attributes, hs->ecdh_psa_type);
+    psa_set_key_bits(&attributes, hs->ecdh_bits);
+
+    /* Fresh private key; only its public half is exported below. */
+    status = psa_generate_key(&attributes, &hs->ecdh_psa_privkey);
+    if (status != PSA_SUCCESS) {
+        ret = psa_ssl_status_to_mbedtls(status);
+        MBEDTLS_SSL_DEBUG_RET(1, "psa_generate_key", ret);
+        return ret;
+    }
+
+    status = psa_export_public_key(hs->ecdh_psa_privkey,
+                                   buf, pubkey_size, &exported_len);
+    if (status != PSA_SUCCESS) {
+        ret = psa_ssl_status_to_mbedtls(status);
+        MBEDTLS_SSL_DEBUG_RET(1, "psa_export_public_key", ret);
+        return ret;
+    }
+
+    *out_len = exported_len;
+    return 0;
+}
+#endif /* PSA_WANT_ALG_FFDH */
+
 /* RFC 8446 section 4.2
  *
  * If an implementation receives an extension which it recognizes and which is
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index 46caa45..c69078d 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -1484,8 +1484,15 @@
      * are derived in the handshake secret derivation stage.
      */
     if (mbedtls_ssl_tls13_key_exchange_mode_with_ephemeral(ssl)) {
-        if (mbedtls_ssl_tls13_named_group_is_ecdhe(handshake->offered_group_id)) {
-#if defined(PSA_WANT_ALG_ECDH)
-            /* Compute ECDH shared secret. */
+        if (mbedtls_ssl_tls13_named_group_is_ecdhe(handshake->offered_group_id) ||
+            mbedtls_ssl_tls13_named_group_is_dhe(handshake->offered_group_id)) {
+#if defined(PSA_WANT_ALG_ECDH) || defined(PSA_WANT_ALG_FFDH)
+            /* The key agreement algorithm follows the negotiated group family:
+             * ECDH for ECDHE groups, FFDH for FFDHE groups. */
+            psa_algorithm_t alg =
+                mbedtls_ssl_tls13_named_group_is_ecdhe(handshake->offered_group_id) ?
+                PSA_ALG_ECDH : PSA_ALG_FFDH;
+
+            /* Compute the (EC)DHE shared secret. */
             psa_status_t status = PSA_ERROR_GENERIC_ERROR;
             psa_key_attributes_t key_attributes = PSA_KEY_ATTRIBUTES_INIT;
@@ -1504,7 +1511,7 @@
             }
 
             status = psa_raw_key_agreement(
-                PSA_ALG_ECDH, handshake->ecdh_psa_privkey,
+                alg, handshake->ecdh_psa_privkey,
                 handshake->ecdh_psa_peerkey, handshake->ecdh_psa_peerkey_len,
                 shared_secret, shared_secret_len, &shared_secret_len);
             if (status != PSA_SUCCESS) {
@@ -1521,7 +1528,7 @@
             }
 
             handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
-#endif /* PSA_WANT_ALG_ECDH */
+#endif /* PSA_WANT_ALG_ECDH || PSA_WANT_ALG_FFDH */
         } else {
             MBEDTLS_SSL_DEBUG_MSG(1, ("Group not supported."));
             return MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index dc3c2f0..31c6b17 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -836,7 +836,7 @@
 
 #define SSL_TLS1_3_PARSE_KEY_SHARES_EXT_NO_MATCH 1
 
-#if defined(PSA_WANT_ALG_ECDH)
+#if defined(PSA_WANT_ALG_ECDH) || defined(PSA_WANT_ALG_FFDH)
 /*
  *  ssl_tls13_parse_key_shares_ext() verifies whether the information in the
  *  extension is correct and stores the first acceptable key share and its
@@ -910,10 +910,11 @@
         }
 
         /*
-         * For now, we only support ECDHE groups.
+         * For now, only ECDHE and FFDHE groups are supported.
          */
-        if (mbedtls_ssl_tls13_named_group_is_ecdhe(group)) {
-            MBEDTLS_SSL_DEBUG_MSG(2, ("ECDH group: %s (%04x)",
+        if (mbedtls_ssl_tls13_named_group_is_ecdhe(group) ||
+            mbedtls_ssl_tls13_named_group_is_dhe(group)) {
+            MBEDTLS_SSL_DEBUG_MSG(2, ("ECDH/FFDH group: %s (%04x)",
                                       mbedtls_ssl_named_group_to_str(group),
                                       group));
             ret = mbedtls_ssl_tls13_read_public_ecdhe_share(
@@ -938,7 +939,7 @@
     }
     return 0;
 }
-#endif /* PSA_WANT_ALG_ECDH */
+#endif /* PSA_WANT_ALG_ECDH || PSA_WANT_ALG_FFDH */
 
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_tls13_client_hello_has_exts(mbedtls_ssl_context *ssl,
@@ -1923,6 +1924,18 @@
         }
     } else
 #endif /* PSA_WANT_ALG_ECDH */
+#if defined(PSA_WANT_ALG_FFDH)
+    if (mbedtls_ssl_tls13_named_group_is_dhe(named_group)) {
+        ret = mbedtls_ssl_tls13_generate_and_write_dhe_key_exchange(
+            ssl, named_group, buf, end, out_len);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(
+                1, "mbedtls_ssl_tls13_generate_and_write_dhe_key_exchange",
+                ret);
+            return ret;
+        }
+    } else
+#endif /* PSA_WANT_ALG_FFDH */
     if (0 /* Other kinds of KEMs */) {
     } else {
         ((void) ssl);