Refactor call hierarchy for ECDH so that it goes through the driver wrapper in a similar fashion to ECDSA.
Add component_test_psa_config_accel_ecdh to all.sh to test key agreement driver wrapper with libtestdriver1.

Signed-off-by: Aditya Deshpande <aditya.deshpande@arm.com>
diff --git a/include/psa/crypto_config.h b/include/psa/crypto_config.h
index 5ab4fde..9f8866b 100644
--- a/include/psa/crypto_config.h
+++ b/include/psa/crypto_config.h
@@ -62,7 +62,7 @@
 #define PSA_WANT_ALG_CHACHA20_POLY1305          1
 #define PSA_WANT_ALG_CTR                        1
 #define PSA_WANT_ALG_DETERMINISTIC_ECDSA        1
-#define PSA_WANT_ALG_ECB_NO_PADDING             1
+//#define PSA_WANT_ALG_ECB_NO_PADDING             1
 #define PSA_WANT_ALG_ECDH                       1
 #define PSA_WANT_ALG_ECDSA                      1
 #define PSA_WANT_ALG_JPAKE                      1
@@ -86,7 +86,7 @@
 #define PSA_WANT_ALG_SHA_256                    1
 #define PSA_WANT_ALG_SHA_384                    1
 #define PSA_WANT_ALG_SHA_512                    1
-#define PSA_WANT_ALG_STREAM_CIPHER              1
+//#define PSA_WANT_ALG_STREAM_CIPHER              1
 #define PSA_WANT_ALG_TLS12_PRF                  1
 #define PSA_WANT_ALG_TLS12_PSK_TO_MS            1
 #define PSA_WANT_ALG_TLS12_ECJPAKE_TO_PMS       1
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 86b84bf..07f3151 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -5735,62 +5735,6 @@
 /****************************************************************/
 /* Key agreement */
 /****************************************************************/
-
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
-static psa_status_t psa_key_agreement_ecdh( const uint8_t *peer_key,
-                                            size_t peer_key_length,
-                                            const mbedtls_ecp_keypair *our_key,
-                                            uint8_t *shared_secret,
-                                            size_t shared_secret_size,
-                                            size_t *shared_secret_length )
-{
-    mbedtls_ecp_keypair *their_key = NULL;
-    mbedtls_ecdh_context ecdh;
-    psa_status_t status;
-    size_t bits = 0;
-    psa_ecc_family_t curve = mbedtls_ecc_group_to_psa( our_key->grp.id, &bits );
-    mbedtls_ecdh_init( &ecdh );
-
-    status = mbedtls_psa_ecp_load_representation(
-                 PSA_KEY_TYPE_ECC_PUBLIC_KEY(curve),
-                 bits,
-                 peer_key,
-                 peer_key_length,
-                 &their_key );
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    status = mbedtls_to_psa_error(
-        mbedtls_ecdh_get_params( &ecdh, their_key, MBEDTLS_ECDH_THEIRS ) );
-    if( status != PSA_SUCCESS )
-        goto exit;
-    status = mbedtls_to_psa_error(
-        mbedtls_ecdh_get_params( &ecdh, our_key, MBEDTLS_ECDH_OURS ) );
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    status = mbedtls_to_psa_error(
-        mbedtls_ecdh_calc_secret( &ecdh,
-                                  shared_secret_length,
-                                  shared_secret, shared_secret_size,
-                                  mbedtls_psa_get_random,
-                                  MBEDTLS_PSA_RANDOM_STATE ) );
-    if( status != PSA_SUCCESS )
-        goto exit;
-    if( PSA_BITS_TO_BYTES( bits ) != *shared_secret_length )
-        status = PSA_ERROR_CORRUPTION_DETECTED;
-
-exit:
-    if( status != PSA_SUCCESS )
-        mbedtls_platform_zeroize( shared_secret, shared_secret_size );
-    mbedtls_ecdh_free( &ecdh );
-    mbedtls_ecp_keypair_free( their_key );
-    mbedtls_free( their_key );
-
-    return( status );
-}
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_ECDH */
-
 #define PSA_KEY_AGREEMENT_MAX_SHARED_SECRET_SIZE MBEDTLS_ECP_MAX_BYTES
 
 psa_status_t psa_key_agreement_raw_builtin( const psa_key_attributes_t *attributes,
@@ -5807,24 +5751,12 @@
     {
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
         case PSA_ALG_ECDH:
-            if( ! PSA_KEY_TYPE_IS_ECC_KEY_PAIR( attributes->core.type ) )
-                return( PSA_ERROR_INVALID_ARGUMENT );
-            mbedtls_ecp_keypair *ecp = NULL;
-            psa_status_t status = mbedtls_psa_ecp_load_representation(
-                                      attributes->core.type,
-                                      attributes->core.bits,
-                                      key_buffer,
-                                      key_buffer_size,
-                                      &ecp );
-            if( status != PSA_SUCCESS )
-                return( status );
-            status = psa_key_agreement_ecdh( peer_key, peer_key_length,
-                                             ecp,
-                                             shared_secret, shared_secret_size,
-                                             shared_secret_length );
-            mbedtls_ecp_keypair_free( ecp );
-            mbedtls_free( ecp );
-            return( status );
+            return( mbedtls_psa_key_agreement_ecdh( attributes, key_buffer,
+                                                   key_buffer_size, alg,
+                                                   peer_key, peer_key_length,
+                                                   shared_secret,
+                                                   shared_secret_size,
+                                                   shared_secret_length ) );
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_ECDH */
         default:
             (void) attributes;
diff --git a/library/psa_crypto_ecp.c b/library/psa_crypto_ecp.c
index 29f53b9..97baef9 100644
--- a/library/psa_crypto_ecp.c
+++ b/library/psa_crypto_ecp.c
@@ -33,6 +33,7 @@
 #include "mbedtls/platform.h"
 
 #include <mbedtls/ecdsa.h>
+#include <mbedtls/ecdh.h>
 #include <mbedtls/ecp.h>
 #include <mbedtls/error.h>
 
@@ -464,4 +465,75 @@
 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_ECDSA) || \
         * defined(MBEDTLS_PSA_BUILTIN_ALG_DETERMINISTIC_ECDSA) */
 
+/****************************************************************/
+/* ECDH Key Agreement */
+/****************************************************************/
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
+psa_status_t mbedtls_psa_key_agreement_ecdh( 
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *peer_key, size_t peer_key_length,
+    uint8_t *shared_secret, size_t shared_secret_size,
+    size_t *shared_secret_length )
+{
+    if( ! PSA_KEY_TYPE_IS_ECC_KEY_PAIR( attributes->core.type ) ||
+        ! PSA_ALG_IS_ECDH(alg) )
+                return( PSA_ERROR_INVALID_ARGUMENT );
+    mbedtls_ecp_keypair *ecp = NULL;
+    psa_status_t status = mbedtls_psa_ecp_load_representation(
+                                attributes->core.type,
+                                attributes->core.bits,
+                                key_buffer,
+                                key_buffer_size,
+                                &ecp );
+    if( status != PSA_SUCCESS )
+        return( status );
+    mbedtls_ecp_keypair *their_key = NULL;
+    mbedtls_ecdh_context ecdh;
+    size_t bits = 0;
+    psa_ecc_family_t curve = mbedtls_ecc_group_to_psa( ecp->grp.id, &bits );
+    mbedtls_ecdh_init( &ecdh );
+
+    status = mbedtls_psa_ecp_load_representation(
+                PSA_KEY_TYPE_ECC_PUBLIC_KEY(curve),
+                bits,
+                peer_key,
+                peer_key_length,
+                &their_key );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    status = mbedtls_to_psa_error(
+        mbedtls_ecdh_get_params( &ecdh, their_key, MBEDTLS_ECDH_THEIRS ) );
+    if( status != PSA_SUCCESS )
+        goto exit;
+    status = mbedtls_to_psa_error(
+        mbedtls_ecdh_get_params( &ecdh, ecp, MBEDTLS_ECDH_OURS ) );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    status = mbedtls_to_psa_error(
+        mbedtls_ecdh_calc_secret( &ecdh,
+                                shared_secret_length,
+                                shared_secret, shared_secret_size,
+                                mbedtls_psa_get_random,
+                                MBEDTLS_PSA_RANDOM_STATE ) );
+    if( status != PSA_SUCCESS )
+        goto exit;
+    if( PSA_BITS_TO_BYTES( bits ) != *shared_secret_length )
+        status = PSA_ERROR_CORRUPTION_DETECTED;
+
+exit:
+    if( status != PSA_SUCCESS )
+        mbedtls_platform_zeroize( shared_secret, shared_secret_size );
+    mbedtls_ecdh_free( &ecdh );
+    mbedtls_ecp_keypair_free( their_key );
+    mbedtls_free( their_key );
+    mbedtls_ecp_keypair_free( ecp );
+    mbedtls_free( ecp );
+    return( status );
+}
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_ECDH */
+
+
 #endif /* MBEDTLS_PSA_CRYPTO_C */
diff --git a/library/psa_crypto_ecp.h b/library/psa_crypto_ecp.h
index 429c062..5a7f6f2 100644
--- a/library/psa_crypto_ecp.h
+++ b/library/psa_crypto_ecp.h
@@ -218,4 +218,11 @@
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
     const uint8_t *signature, size_t signature_length );
+
+psa_status_t mbedtls_psa_key_agreement_ecdh( 
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *peer_key, size_t peer_key_length,
+    uint8_t *shared_secret, size_t shared_secret_size,
+    size_t *shared_secret_length );
 #endif /* PSA_CRYPTO_ECP_H */
diff --git a/tests/include/test/drivers/crypto_config_test_driver_extension.h b/tests/include/test/drivers/crypto_config_test_driver_extension.h
index 0bbca4a..fbfe8da 100644
--- a/tests/include/test/drivers/crypto_config_test_driver_extension.h
+++ b/tests/include/test/drivers/crypto_config_test_driver_extension.h
@@ -54,6 +54,14 @@
 #endif
 #endif
 
+#if defined(PSA_WANT_ALG_ECDH)
+#if defined(MBEDTLS_PSA_ACCEL_ALG_ECDH)
+#undef MBEDTLS_PSA_ACCEL_ALG_ECDH
+#else
+#define MBEDTLS_PSA_ACCEL_ALG_ECDH 1
+#endif
+#endif
+
 #if defined(PSA_WANT_ALG_MD5)
 #if defined(MBEDTLS_PSA_ACCEL_ALG_MD5)
 #undef MBEDTLS_PSA_ACCEL_ALG_MD5
@@ -202,7 +210,6 @@
 #define MBEDTLS_PSA_ACCEL_ALG_CCM 1
 #define MBEDTLS_PSA_ACCEL_ALG_CMAC 1
 #define MBEDTLS_PSA_ACCEL_ALG_ECB_NO_PADDING 1
-#define MBEDTLS_PSA_ACCEL_ALG_ECDH 1
 #define MBEDTLS_PSA_ACCEL_ALG_GCM 1
 #define MBEDTLS_PSA_ACCEL_ALG_HKDF 1
 #define MBEDTLS_PSA_ACCEL_ALG_HKDF_EXTRACT 1
@@ -215,6 +222,7 @@
 #define MBEDTLS_PSA_ACCEL_ALG_TLS12_PSK_TO_MS 1
 
 #if defined(MBEDTLS_PSA_ACCEL_ALG_ECDSA)
+#if defined(MBEDTLS_PSA_ACCEL_ALG_ECDH)
 #define MBEDTLS_PSA_ACCEL_ECC_BRAINPOOL_P_R1_256 1
 #define MBEDTLS_PSA_ACCEL_ECC_BRAINPOOL_P_R1_384 1
 #define MBEDTLS_PSA_ACCEL_ECC_BRAINPOOL_P_R1_512 1
@@ -229,6 +237,7 @@
 #define MBEDTLS_PSA_ACCEL_ECC_SECP_R1_384 1
 #define MBEDTLS_PSA_ACCEL_ECC_SECP_R1_521 1
 #endif
+#endif
 
 #define MBEDTLS_PSA_ACCEL_KEY_TYPE_DERIVE 1
 #define MBEDTLS_PSA_ACCEL_KEY_TYPE_HMAC 1
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index a4c6c86..7f1723b 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -1885,6 +1885,46 @@
     make test
 }
 
+component_test_psa_crypto_config_accel_ecdh () {
+    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated ECDH"
+
+    # Disable ALG_STREAM_CIPHER and ALG_ECB_NO_PADDING to avoid having
+    # partial support for cipher operations in the driver test library.
+    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_STREAM_CIPHER
+    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_ECB_NO_PADDING
+
+    # SHA384 needed for some ECDSA signature tests.
+    scripts/config.py -f tests/include/test/drivers/config_test_driver.h set MBEDTLS_SHA384_C
+    scripts/config.py -f tests/include/test/drivers/config_test_driver.h set MBEDTLS_SHA512_C
+
+    loc_accel_list="ALG_ECDH KEY_TYPE_ECC_KEY_PAIR KEY_TYPE_ECC_PUBLIC_KEY"
+    loc_accel_flags=$( echo "$loc_accel_list" | sed 's/[^ ]* */-DLIBTESTDRIVER1_MBEDTLS_PSA_ACCEL_&/g' )
+    make -C tests libtestdriver1.a CFLAGS=" -g3 $ASAN_CFLAGS $loc_accel_flags" LDFLAGS="$ASAN_CFLAGS"
+
+    # Restore test driver base configuration
+    scripts/config.py -f tests/include/test/drivers/config_test_driver.h unset MBEDTLS_SHA384_C
+    scripts/config.py -f tests/include/test/drivers/config_test_driver.h unset MBEDTLS_SHA512_C
+
+    scripts/config.py set MBEDTLS_PSA_CRYPTO_DRIVERS
+    scripts/config.py set MBEDTLS_PSA_CRYPTO_CONFIG
+    scripts/config.py unset MBEDTLS_USE_PSA_CRYPTO
+    scripts/config.py unset MBEDTLS_SSL_PROTO_TLS1_3
+    scripts/config.py unset MBEDTLS_ECDH_C
+    scripts/config.py unset MBEDTLS_KEY_EXCHANGE_ECDH_RSA_ENABLED
+    scripts/config.py unset MBEDTLS_KEY_EXCHANGE_ECDH_ECDSA_ENABLED
+    scripts/config.py unset MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
+    scripts/config.py unset MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED
+    scripts/config.py unset MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED
+
+    loc_accel_flags="$loc_accel_flags $( echo "$loc_accel_list" | sed 's/[^ ]* */-DMBEDTLS_PSA_ACCEL_&/g' )"
+    make CFLAGS="$ASAN_CFLAGS -O -Werror -I../tests/include -I../tests -I../../tests -DPSA_CRYPTO_DRIVER_TEST -DMBEDTLS_TEST_LIBTESTDRIVER1 -g3 $loc_accel_flags" LDFLAGS="-ltestdriver1 $ASAN_CFLAGS"
+
+    not grep mbedtls_ecdh_ library/ecdh.o
+
+    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated ECDH"
+    make test
+}
+
 component_test_psa_crypto_config_accel_rsa_signature () {
     msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated RSA signature"
 
diff --git a/tests/src/drivers/test_driver_key_agreement.c b/tests/src/drivers/test_driver_key_agreement.c
index 884899f..ccea61d 100644
--- a/tests/src/drivers/test_driver_key_agreement.c
+++ b/tests/src/drivers/test_driver_key_agreement.c
@@ -19,15 +19,22 @@
 
 #include <test/helpers.h>
 
+#if defined(MBEDTLS_PSA_CRYPTO_DRIVERS) && defined(PSA_CRYPTO_DRIVER_TEST)
+
 #include "psa/crypto.h"
 #include "psa_crypto_core.h"
+#include "psa_crypto_ecp.h"
 
 #include "test/drivers/key_agreement.h"
 #include "test/drivers/test_driver.h"
 
 #include <string.h>
+#include <stdio.h>
 
-#if defined(MBEDTLS_PSA_CRYPTO_DRIVERS) && defined(PSA_CRYPTO_DRIVER_TEST)
+#if defined(MBEDTLS_TEST_LIBTESTDRIVER1)
+#include "libtestdriver1/include/psa/crypto.h"
+#include "libtestdriver1/library/psa_crypto_ecp.h"
+#endif
 
 mbedtls_test_driver_key_agreement_hooks_t
     mbedtls_test_driver_key_agreement_hooks = MBEDTLS_TEST_DRIVER_KEY_AGREEMENT_INIT;
@@ -58,16 +65,30 @@
         return( PSA_SUCCESS );
     }
 
-    return( psa_key_agreement_raw_builtin(
+    if( PSA_ALG_IS_ECDH(alg) )
+    {
+#if defined(MBEDTLS_TEST_LIBTESTDRIVER1) && \
+    (LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_ALG_ECDH)
+        return( libtestdriver1_mbedtls_psa_key_agreement_ecdh(
+                    (const libtestdriver1_psa_key_attributes_t *) attributes,
+                    key_buffer, key_buffer_size,
+                    alg, peer_key, peer_key_length,
+                    shared_secret, shared_secret_size,
+                    shared_secret_length ) );
+#elif defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
+        return( mbedtls_psa_key_agreement_ecdh(
                 attributes,
-                key_buffer,
-                key_buffer_size,
-                alg,
-                peer_key,
-                peer_key_length,
-                shared_secret,
-                shared_secret_size,
+                key_buffer, key_buffer_size,
+                alg, peer_key, peer_key_length,
+                shared_secret, shared_secret_size,
                 shared_secret_length ) );
+#endif   
+    }
+    else
+    {
+        return( PSA_ERROR_INVALID_ARGUMENT );
+    }
+
 }
 
 #endif /* MBEDTLS_PSA_CRYPTO_DRIVERS && PSA_CRYPTO_DRIVER_TEST */