Merge pull request #7577 from mprse/ffdh_drivers

FFDH 3b: add driver testing (no TLS 1.3)
diff --git a/ChangeLog.d/driver-ffdh.txt b/ChangeLog.d/driver-ffdh.txt
new file mode 100644
index 0000000..1185133
--- /dev/null
+++ b/ChangeLog.d/driver-ffdh.txt
@@ -0,0 +1,3 @@
+Features
+   * Add a driver dispatch layer for FFDH keys, enabling alternative
+     implementations of FFDH through the driver entry points.
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 81427ac..85451bf 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -646,14 +646,11 @@
             if (psa_is_dh_key_size_valid(PSA_BYTES_TO_BITS(data_length)) == 0) {
                 return PSA_ERROR_INVALID_ARGUMENT;
             }
-
-            /* Copy the key material. */
-            memcpy(key_buffer, data, data_length);
-            *key_buffer_length = data_length;
-            *bits = PSA_BYTES_TO_BITS(data_length);
-            (void) key_buffer_size;
-
-            return PSA_SUCCESS;
+            return mbedtls_psa_ffdh_import_key(attributes,
+                                               data, data_length,
+                                               key_buffer, key_buffer_size,
+                                               key_buffer_length,
+                                               bits);
         }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) ||
         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY) */
@@ -1474,6 +1471,11 @@
 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) ||
         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY) */
     } else {
+        (void) key_buffer;
+        (void) key_buffer_size;
+        (void) data;
+        (void) data_size;
+        (void) data_length;
         return PSA_ERROR_NOT_SUPPORTED;
     }
 }
diff --git a/library/psa_crypto_ffdh.c b/library/psa_crypto_ffdh.c
index 6e34eaa..4550a72 100644
--- a/library/psa_crypto_ffdh.c
+++ b/library/psa_crypto_ffdh.c
@@ -26,9 +26,11 @@
 #include "psa_crypto_core.h"
 #include "psa_crypto_ffdh.h"
 #include "psa_crypto_random_impl.h"
+#include "mbedtls/platform.h"
 
-#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \
-    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY)
+#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) ||   \
+    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY) || \
+    defined(MBEDTLS_PSA_BUILTIN_ALG_FFDH)
 static psa_status_t mbedtls_psa_ffdh_set_prime_generator(size_t key_size,
                                                          mbedtls_mpi *P,
                                                          mbedtls_mpi *G)
@@ -115,6 +117,119 @@
 
     return PSA_SUCCESS;
 }
+#endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR ||
+          MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY ||
+          MBEDTLS_PSA_BUILTIN_ALG_FFDH */
+
+#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \
+    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY)
+psa_status_t mbedtls_psa_export_ffdh_public_key(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer,
+    size_t key_buffer_size,
+    uint8_t *data,
+    size_t data_size,
+    size_t *data_length)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    mbedtls_mpi GX, G, X, P;
+    psa_key_type_t type = attributes->core.type;
+
+    /* Output length is fixed by the key (group) size, not by data_size. */
+    if (key_buffer_size > data_size) {
+        return PSA_ERROR_BUFFER_TOO_SMALL;
+    }
+
+    if (PSA_KEY_TYPE_IS_PUBLIC_KEY(type)) {
+        memcpy(data, key_buffer, key_buffer_size);
+        memset(data + key_buffer_size, 0,
+               data_size - key_buffer_size);
+        *data_length = key_buffer_size;
+        return PSA_SUCCESS;
+    }
+
+    mbedtls_mpi_init(&GX); mbedtls_mpi_init(&G);
+    mbedtls_mpi_init(&X); mbedtls_mpi_init(&P);
+
+    status = mbedtls_psa_ffdh_set_prime_generator(key_buffer_size, &P, &G);
+    if (status != PSA_SUCCESS) {
+        goto cleanup;
+    }
+
+    /* Public value: G^X mod P, big-endian, padded to the group size. */
+    MBEDTLS_MPI_CHK(mbedtls_mpi_read_binary(&X, key_buffer,
+                                            key_buffer_size));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod(&GX, &G, &X, &P, NULL));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_write_binary(&GX, data, key_buffer_size));
+    *data_length = key_buffer_size;
+
+    ret = 0;
+cleanup:
+    mbedtls_mpi_free(&P); mbedtls_mpi_free(&G);
+    mbedtls_mpi_free(&X); mbedtls_mpi_free(&GX);
+
+    if (status == PSA_SUCCESS && ret != 0) {
+        status = mbedtls_to_psa_error(ret);
+    }
+
+    return status;
+}
+
+psa_status_t mbedtls_psa_ffdh_generate_key(
+    const psa_key_attributes_t *attributes,
+    uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
+{
+    mbedtls_mpi X, P;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    mbedtls_mpi_init(&P); mbedtls_mpi_init(&X);
+    (void) attributes; /* The group is selected from key_buffer_size. */
+
+    status = mbedtls_psa_ffdh_set_prime_generator(key_buffer_size, &P, NULL);
+
+    if (status != PSA_SUCCESS) {
+        goto cleanup;
+    }
+
+    /* RFC7919: Traditional finite field Diffie-Hellman has each peer choose
+     * their secret exponent from the range [2, P-2].
+     * Select a random value in the range [3, P-1] and decrease it by 1. */
+    MBEDTLS_MPI_CHK(mbedtls_mpi_random(&X, 3, &P, mbedtls_psa_get_random,
+                                       MBEDTLS_PSA_RANDOM_STATE));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_sub_int(&X, &X, 1));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_write_binary(&X, key_buffer, key_buffer_size));
+    *key_buffer_length = key_buffer_size;
+
+cleanup:
+    mbedtls_mpi_free(&P); mbedtls_mpi_free(&X);
+    if (status == PSA_SUCCESS && ret != 0) {
+        return mbedtls_to_psa_error(ret);
+    }
+
+    return status;
+}
+
+psa_status_t mbedtls_psa_ffdh_import_key(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *data, size_t data_length,
+    uint8_t *key_buffer, size_t key_buffer_size,
+    size_t *key_buffer_length, size_t *bits)
+{
+    (void) attributes;
+    /* Key bytes are stored as-is; the DH group size is not checked here. */
+    if (key_buffer_size < data_length) {
+        return PSA_ERROR_BUFFER_TOO_SMALL;
+    }
+    memcpy(key_buffer, data, data_length);
+    *key_buffer_length = data_length;
+    *bits = PSA_BYTES_TO_BITS(data_length);
+
+    return PSA_SUCCESS;
+}
+
+#endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR ||
+          MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY */
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_FFDH)
 psa_status_t mbedtls_psa_key_agreement_ffdh(
@@ -181,82 +296,4 @@
 }
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_FFDH */
 
-psa_status_t mbedtls_psa_export_ffdh_public_key(
-    const psa_key_attributes_t *attributes,
-    const uint8_t *key_buffer,
-    size_t key_buffer_size,
-    uint8_t *data,
-    size_t data_size,
-    size_t *data_length)
-{
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    mbedtls_mpi GX, G, X, P;
-    (void) attributes;
-
-    mbedtls_mpi_init(&GX); mbedtls_mpi_init(&G);
-    mbedtls_mpi_init(&X); mbedtls_mpi_init(&P);
-
-    status = mbedtls_psa_ffdh_set_prime_generator(data_size, &P, &G);
-
-    if (status != PSA_SUCCESS) {
-        goto cleanup;
-    }
-
-    MBEDTLS_MPI_CHK(mbedtls_mpi_read_binary(&X, key_buffer,
-                                            key_buffer_size));
-
-    MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod(&GX, &G, &X, &P, NULL));
-    MBEDTLS_MPI_CHK(mbedtls_mpi_write_binary(&GX, data, data_size));
-
-    *data_length = data_size;
-
-    ret = 0;
-cleanup:
-    mbedtls_mpi_free(&P); mbedtls_mpi_free(&G);
-    mbedtls_mpi_free(&X); mbedtls_mpi_free(&GX);
-
-    if (status == PSA_SUCCESS && ret != 0) {
-        status = mbedtls_to_psa_error(ret);
-    }
-
-    return status;
-}
-
-psa_status_t mbedtls_psa_ffdh_generate_key(
-    const psa_key_attributes_t *attributes,
-    uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
-{
-    mbedtls_mpi X, P;
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    mbedtls_mpi_init(&P); mbedtls_mpi_init(&X);
-    (void) attributes;
-
-    status = mbedtls_psa_ffdh_set_prime_generator(key_buffer_size, &P, NULL);
-
-    if (status != PSA_SUCCESS) {
-        goto cleanup;
-    }
-
-    /* RFC7919: Traditional finite field Diffie-Hellman has each peer choose their
-        secret exponent from the range [2, P-2].
-        Select random value in range [3, P-1] and decrease it by 1. */
-    MBEDTLS_MPI_CHK(mbedtls_mpi_random(&X, 3, &P, mbedtls_psa_get_random,
-                                       MBEDTLS_PSA_RANDOM_STATE));
-    MBEDTLS_MPI_CHK(mbedtls_mpi_sub_int(&X, &X, 1));
-    MBEDTLS_MPI_CHK(mbedtls_mpi_write_binary(&X, key_buffer, key_buffer_size));
-    *key_buffer_length = key_buffer_size;
-
-cleanup:
-    mbedtls_mpi_free(&P); mbedtls_mpi_free(&X);
-    if (status == PSA_SUCCESS && ret != 0) {
-        return mbedtls_to_psa_error(ret);
-    }
-
-    return status;
-}
-#endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR ||
-          MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY */
-
 #endif /* MBEDTLS_PSA_CRYPTO_C */
diff --git a/library/psa_crypto_ffdh.h b/library/psa_crypto_ffdh.h
index 62b05b2..5d7d951 100644
--- a/library/psa_crypto_ffdh.h
+++ b/library/psa_crypto_ffdh.h
@@ -112,4 +112,33 @@
     size_t key_buffer_size,
     size_t *key_buffer_length);
 
+/**
+ * \brief Import a DH key pair or public key.
+ *
+ * \note The signature of the function is that of a PSA driver import_key
+ *       entry point. The key data is copied as is.
+ *
+ * \param[in]  attributes       The attributes for the key to import.
+ * \param[in]  data             The buffer containing the key data in import
+ *                              format.
+ * \param[in]  data_length      Size of the \p data buffer in bytes.
+ * \param[out] key_buffer       The buffer containing the key data in output
+ *                              format.
+ * \param[in]  key_buffer_size  Size of the \p key_buffer buffer in bytes. It
+ *                              must be at least \p data_length.
+ * \param[out] key_buffer_length  The length of the data written in \p
+ *                                key_buffer in bytes.
+ * \param[out] bits             The key size in bits (8 * \p data_length).
+ *
+ * \retval #PSA_SUCCESS
+ *         The key was imported successfully.
+ * \retval #PSA_ERROR_BUFFER_TOO_SMALL
+ *         \p key_buffer_size is smaller than \p data_length.
+ */
+psa_status_t mbedtls_psa_ffdh_import_key(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *data, size_t data_length,
+    uint8_t *key_buffer, size_t key_buffer_size,
+    size_t *key_buffer_length, size_t *bits);
+
 #endif /* PSA_CRYPTO_FFDH_H */
diff --git a/tests/include/test/drivers/crypto_config_test_driver_extension.h b/tests/include/test/drivers/crypto_config_test_driver_extension.h
index 10d8e6e..f8b3a34 100644
--- a/tests/include/test/drivers/crypto_config_test_driver_extension.h
+++ b/tests/include/test/drivers/crypto_config_test_driver_extension.h
@@ -206,6 +206,14 @@
 #endif
 #endif
 
+#if defined(PSA_WANT_KEY_TYPE_DH_KEY_PAIR)
+#if defined(MBEDTLS_PSA_ACCEL_KEY_TYPE_DH_KEY_PAIR)
+#undef MBEDTLS_PSA_ACCEL_KEY_TYPE_DH_KEY_PAIR
+#else
+#define MBEDTLS_PSA_ACCEL_KEY_TYPE_DH_KEY_PAIR 1
+#endif
+#endif
+
 #if defined(PSA_WANT_KEY_TYPE_RSA_KEY_PAIR)
 #if defined(MBEDTLS_PSA_ACCEL_KEY_TYPE_RSA_KEY_PAIR)
 #undef MBEDTLS_PSA_ACCEL_KEY_TYPE_RSA_KEY_PAIR
@@ -283,3 +291,4 @@
 #define MBEDTLS_PSA_ACCEL_KEY_TYPE_ECC_PUBLIC_KEY 1
 #define MBEDTLS_PSA_ACCEL_KEY_TYPE_RAW_DATA 1
 #define MBEDTLS_PSA_ACCEL_KEY_TYPE_RSA_PUBLIC_KEY 1
+#define MBEDTLS_PSA_ACCEL_KEY_TYPE_DH_PUBLIC_KEY 1
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index 78666b4..4b22040 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -2160,6 +2160,50 @@
     make test
 }
 
+component_test_psa_crypto_config_accel_ffdh () {
+    msg "build: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated FFDH"
+
+    # Algorithms and key types to accelerate
+    loc_accel_list="ALG_FFDH KEY_TYPE_DH_KEY_PAIR KEY_TYPE_DH_PUBLIC_KEY"
+
+    # Configure and build the test driver library
+    # -------------------------------------------
+
+    # Disable ALG_STREAM_CIPHER and ALG_ECB_NO_PADDING to avoid having
+    # partial support for cipher operations in the driver test library.
+    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_STREAM_CIPHER
+    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_ECB_NO_PADDING
+
+    loc_accel_flags=$( echo "$loc_accel_list" | sed 's/[^ ]* */-DLIBTESTDRIVER1_MBEDTLS_PSA_ACCEL_&/g' )
+    make -C tests libtestdriver1.a CFLAGS=" $ASAN_CFLAGS $loc_accel_flags" LDFLAGS="$ASAN_CFLAGS"
+
+    # Configure and build the main libraries
+    # --------------------------------------
+
+    # Start from default config (no USE_PSA or TLS 1.3)
+    scripts/config.py set MBEDTLS_PSA_CRYPTO_CONFIG
+
+    # Disable the legacy DHM module: FFDH is provided by the driver instead
+    scripts/config.py unset MBEDTLS_DHM_C
+
+    # Disable the DHE key exchanges, which depend on MBEDTLS_DHM_C
+    scripts/config.py unset MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED
+    scripts/config.py unset MBEDTLS_KEY_EXCHANGE_DHE_RSA_ENABLED
+
+    # Build the main library
+    loc_accel_flags="$loc_accel_flags $( echo "$loc_accel_list" | sed 's/[^ ]* */-DMBEDTLS_PSA_ACCEL_&/g' )"
+    make CFLAGS="$ASAN_CFLAGS -O -Werror -I../tests/include -I../tests -I../../tests -DPSA_CRYPTO_DRIVER_TEST -DMBEDTLS_TEST_LIBTESTDRIVER1 $loc_accel_flags" LDFLAGS="-ltestdriver1 $ASAN_CFLAGS"
+
+    # Make sure the DHM module was not re-enabled by accident (additive config)
+    not grep mbedtls_dhm_ library/dhm.o
+
+    # Run the tests
+    # -------------
+
+    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated FFDH"
+    make test
+}
+
 component_test_psa_crypto_config_accel_pake() {
     msg "build: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE"
 
diff --git a/tests/src/drivers/test_driver_key_agreement.c b/tests/src/drivers/test_driver_key_agreement.c
index 843ebf9..6cfde20 100644
--- a/tests/src/drivers/test_driver_key_agreement.c
+++ b/tests/src/drivers/test_driver_key_agreement.c
@@ -34,6 +34,7 @@
 #if defined(MBEDTLS_TEST_LIBTESTDRIVER1)
 #include "libtestdriver1/include/psa/crypto.h"
 #include "libtestdriver1/library/psa_crypto_ecp.h"
+#include "libtestdriver1/library/psa_crypto_ffdh.h"
 #endif
 
 mbedtls_test_driver_key_agreement_hooks_t
@@ -101,8 +102,8 @@
         defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_ALG_FFDH))
         return libtestdriver1_mbedtls_psa_key_agreement_ffdh(
             (const libtestdriver1_psa_key_attributes_t *) attributes,
+            peer_key, peer_key_length,
             key_buffer, key_buffer_size,
-            alg, peer_key, peer_key_length,
             shared_secret, shared_secret_size,
             shared_secret_length);
 #elif defined(MBEDTLS_PSA_BUILTIN_ALG_FFDH)
diff --git a/tests/src/drivers/test_driver_key_management.c b/tests/src/drivers/test_driver_key_management.c
index a3ff2dd..3ff1053 100644
--- a/tests/src/drivers/test_driver_key_management.c
+++ b/tests/src/drivers/test_driver_key_management.c
@@ -25,6 +25,7 @@
 #include "psa_crypto_core.h"
 #include "psa_crypto_ecp.h"
 #include "psa_crypto_rsa.h"
+#include "psa_crypto_ffdh.h"
 #include "mbedtls/ecp.h"
 #include "mbedtls/error.h"
 
@@ -36,6 +37,7 @@
 #if defined(MBEDTLS_TEST_LIBTESTDRIVER1)
 #include "libtestdriver1/library/psa_crypto_ecp.h"
 #include "libtestdriver1/library/psa_crypto_rsa.h"
+#include "libtestdriver1/library/psa_crypto_ffdh.h"
 #endif
 
 #include <string.h>
@@ -240,6 +242,17 @@
         return mbedtls_psa_rsa_generate_key(
             attributes, key, key_size, key_length);
 #endif
+    } else if (PSA_KEY_TYPE_IS_DH(psa_get_key_type(attributes))
+               && PSA_KEY_TYPE_IS_KEY_PAIR(psa_get_key_type(attributes))) {
+#if defined(MBEDTLS_TEST_LIBTESTDRIVER1) && \
+        defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR)
+        return libtestdriver1_mbedtls_psa_ffdh_generate_key(
+            (const libtestdriver1_psa_key_attributes_t *) attributes,
+            key, key_size, key_length);
+#elif defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR)
+        return mbedtls_psa_ffdh_generate_key(
+            attributes, key, key_size, key_length);
+#endif
     }
 
     (void) attributes;
@@ -309,8 +322,24 @@
             key_buffer, key_buffer_size,
             key_buffer_length, bits);
 #endif
+    } else if (PSA_KEY_TYPE_IS_DH(type)) { /* DH key pair or public key */
+#if defined(MBEDTLS_TEST_LIBTESTDRIVER1) && \
+        (defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \
+        defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY))
+        return libtestdriver1_mbedtls_psa_ffdh_import_key(
+            (const libtestdriver1_psa_key_attributes_t *) attributes,
+            data, data_length,
+            key_buffer, key_buffer_size,
+            key_buffer_length, bits);
+#elif defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \
+        defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY)
+        return mbedtls_psa_ffdh_import_key(
+            attributes,
+            data, data_length,
+            key_buffer, key_buffer_size,
+            key_buffer_length, bits);
+#endif /* Must match the DH guard around mbedtls_psa_ffdh_import_key(). */
     }
-
     (void) data;
     (void) data_length;
     (void) key_buffer;
@@ -560,6 +589,21 @@
             key_buffer, key_buffer_size,
             data, data_size, data_length);
 #endif
+    } else if (PSA_KEY_TYPE_IS_DH(key_type)) {
+#if defined(MBEDTLS_TEST_LIBTESTDRIVER1) && \
+        (defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \
+        defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY))
+        return libtestdriver1_mbedtls_psa_export_ffdh_public_key(
+            (const libtestdriver1_psa_key_attributes_t *) attributes,
+            key_buffer, key_buffer_size,
+            data, data_size, data_length);
+#elif defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \
+        defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY)
+        return mbedtls_psa_export_ffdh_public_key(
+            attributes,
+            key_buffer, key_buffer_size,
+            data, data_size, data_length);
+#endif
     }
 
     (void) key_buffer;