Merge remote-tracking branch 'origin/pr/2338' into development
diff --git a/include/mbedtls/check_config.h b/include/mbedtls/check_config.h
index 3d47899..ea05938 100644
--- a/include/mbedtls/check_config.h
+++ b/include/mbedtls/check_config.h
@@ -114,14 +114,15 @@
 #endif
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)           && \
-    ( defined(MBEDTLS_ECDH_COMPUTE_SHARED_ALT) || \
+    ( defined(MBEDTLS_USE_PSA_CRYPTO) /* TLS ECDH via PSA has no restartable mode */ || \
+      defined(MBEDTLS_ECDH_COMPUTE_SHARED_ALT) || \
       defined(MBEDTLS_ECDH_GEN_PUBLIC_ALT)     || \
       defined(MBEDTLS_ECDSA_SIGN_ALT)          || \
       defined(MBEDTLS_ECDSA_VERIFY_ALT)        || \
       defined(MBEDTLS_ECDSA_GENKEY_ALT)        || \
       defined(MBEDTLS_ECP_INTERNAL_ALT)        || \
       defined(MBEDTLS_ECP_ALT) )
-#error "MBEDTLS_ECP_RESTARTABLE defined, but it cannot coexist with an alternative ECP implementation"
+#error "MBEDTLS_ECP_RESTARTABLE defined, but it cannot coexist with an alternative or PSA-based ECP implementation"
 #endif
 
 #if defined(MBEDTLS_ECDSA_DETERMINISTIC) && !defined(MBEDTLS_HMAC_DRBG_C)
diff --git a/include/mbedtls/psa_util.h b/include/mbedtls/psa_util.h
index fbf25e6..b0c0428 100644
--- a/include/mbedtls/psa_util.h
+++ b/include/mbedtls/psa_util.h
@@ -43,6 +43,8 @@
 #include "pk.h"
 #include "oid.h"
 
+#include <string.h>
+
 /* Translations for symmetric crypto. */
 
 static inline psa_key_type_t mbedtls_psa_translate_cipher_type(
@@ -233,6 +235,86 @@
      return( -1 );
 }
 
+/* Max PSA EC public key length, 2 * ceil(bits/8) + 1, over enabled curves */
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH 1 /* used if no curve is enabled */
+
+#if defined(MBEDTLS_ECP_DP_SECP192R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP192R1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_SECP224R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP224R1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_SECP256R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP256R1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_SECP384R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP384R1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_SECP521R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 521 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 521 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP521R1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_SECP192K1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP192K1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_SECP224K1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP224K1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_SECP256K1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_SECP256K1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_BP256R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_BP256R1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_BP384R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_BP384R1_ENABLED */
+
+#if defined(MBEDTLS_ECP_DP_BP512R1_ENABLED)
+#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 512 + 7 ) / 8 ) + 1 )
+#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 512 + 7 ) / 8 ) + 1 )
+#endif
+#endif /* MBEDTLS_ECP_DP_BP512R1_ENABLED */
+
 static inline psa_ecc_curve_t mbedtls_psa_translate_ecc_group( mbedtls_ecp_group_id grpid )
 {
     switch( grpid )
@@ -294,6 +376,7 @@
     }
 }
 
+
 #define MBEDTLS_PSA_ECC_KEY_BITS_OF_CURVE( curve )                \
     ( curve == PSA_ECC_CURVE_SECP192R1        ? 192 :             \
       curve == PSA_ECC_CURVE_SECP224R1        ? 224 :             \
@@ -352,6 +435,48 @@
     return( (psa_ecc_curve_t) tls_ecc_grp_reg_id );
 }
 
+/* This function takes a buffer holding an EC public key
+ * exported through psa_export_public_key(), and converts
+ * it into an ECPoint structure to be put into a ClientKeyExchange
+ * message in an ECDHE exchange.
+ *
+ * Both the present and the foreseeable future format of EC public keys
+ * used by PSA have the ECPoint structure contained in the exported key
+ * as a subbuffer, and the function merely selects this subbuffer instead
+ * of making a copy.
+ */
+static inline int mbedtls_psa_tls_psa_ec_to_ecpoint( unsigned char *src,
+                                                     size_t srclen,
+                                                     unsigned char **dst,
+                                                     size_t *dstlen )
+{
+    *dstlen = srclen; /* ECPoint currently spans the whole export: no copy */
+    *dst = src;
+    return( 0 );
+}
+
+/* This function takes a buffer holding an ECPoint structure
+ * (as contained in a TLS ServerKeyExchange message for ECDHE
+ * exchanges) and converts it into a format that the PSA key
+ * agreement API understands.
+ */
+static inline int mbedtls_psa_tls_ecpoint_to_psa_ec( psa_ecc_curve_t curve,
+                                                     unsigned char const *src,
+                                                     size_t srclen,
+                                                     unsigned char *dst,
+                                                     size_t dstlen,
+                                                     size_t *olen )
+{
+    (void) curve; /* unused: the ECPoint is copied through unchanged */
+    if( srclen <= dstlen )
+    {
+        memcpy( dst, src, srclen );
+        *olen = srclen;
+        return( 0 );
+    }
+    return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );
+}
+
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #endif /* MBEDTLS_PSA_UTIL_H */
diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h
index 3159cd3..be7f41b 100644
--- a/include/mbedtls/ssl_internal.h
+++ b/include/mbedtls/ssl_internal.h
@@ -57,6 +57,11 @@
 #include "ecjpake.h"
 #endif
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#include "psa/crypto.h"
+#include "psa_util.h"
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 #if ( defined(__ARMCC_VERSION) || defined(_MSC_VER) ) && \
     !defined(inline) && !defined(__cplusplus)
 #define inline __inline
@@ -280,7 +285,15 @@
 #endif
 #if defined(MBEDTLS_ECDH_C)
     mbedtls_ecdh_context ecdh_ctx;              /*!<  ECDH key exchange       */
-#endif
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_ecc_curve_t ecdh_psa_curve;
+    psa_key_handle_t ecdh_psa_privkey;
+    unsigned char ecdh_psa_peerkey[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
+    size_t ecdh_psa_peerkey_len;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+#endif /* MBEDTLS_ECDH_C */
+
 #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED)
     mbedtls_ecjpake_context ecjpake_ctx;        /*!< EC J-PAKE key exchange */
 #if defined(MBEDTLS_SSL_CLI_C)
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index ec36401..87fa1e0 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -39,6 +39,10 @@
 #include "mbedtls/ssl.h"
 #include "mbedtls/ssl_internal.h"
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#include "mbedtls/psa_util.h"
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 #include <string.h>
 
 #include <stdint.h>
@@ -2109,6 +2113,64 @@
           MBEDTLS_KEY_EXCHANGE_ECDH_RSA_ENABLED ||
           MBEDTLS_KEY_EXCHANGE_ECDH_ECDSA_ENABLED */
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
+        ( defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||     \
+          defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) )
+static int ssl_parse_server_ecdh_params_psa( mbedtls_ssl_context *ssl,
+                                             unsigned char **p,
+                                             unsigned char *end )
+{
+    uint16_t tls_id;
+    uint8_t ecpoint_len;
+    const mbedtls_ecp_curve_info *curve_info;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+    /* Need curve_type (1), named_curve (2) and the point length (1). */
+    if( end - *p < 4 )
+        return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+
+    /* First byte is curve_type; only named_curve is handled */
+    if( *(*p)++ != MBEDTLS_ECP_TLS_NAMED_CURVE )
+        return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+
+    /* Next two bytes are the namedcurve value */
+    tls_id = *(*p)++;
+    tls_id <<= 8;
+    tls_id |= *(*p)++;
+
+    /* Reject curves that are unknown to us or absent from the configured
+     * curve list: the server must pick one of the curves we offered. */
+    curve_info = mbedtls_ecp_curve_info_from_tls_id( tls_id );
+    if( curve_info == NULL ||
+        mbedtls_ssl_check_curve( ssl, curve_info->grp_id ) != 0 )
+        return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+
+    /* Convert EC group to PSA key type. */
+    if( ( handshake->ecdh_psa_curve =
+          mbedtls_psa_parse_tls_ecc_group( tls_id ) ) == 0 )
+        return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+
+    /* Put peer's ECDH public key in the format understood by PSA. */
+    ecpoint_len = *(*p)++;
+    if( (size_t)( end - *p ) < ecpoint_len )
+        return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+
+    if( mbedtls_psa_tls_ecpoint_to_psa_ec( handshake->ecdh_psa_curve,
+                                    *p, ecpoint_len,
+                                    handshake->ecdh_psa_peerkey,
+                                    sizeof( handshake->ecdh_psa_peerkey ),
+                                    &handshake->ecdh_psa_peerkey_len ) != 0 )
+    {
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+    }
+
+    *p += ecpoint_len;
+    return( 0 );
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+            ( MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED ||
+              MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED ) */
+
 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||                     \
     defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) ||                   \
     defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
@@ -2510,6 +2572,24 @@
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_RSA_ENABLED ||
           MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
+        ( defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||     \
+          defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) )
+    if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA ||
+        ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA )
+    {
+        if( ssl_parse_server_ecdh_params_psa( ssl, &p, end ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) );
+            mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
+                                            MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER );
+            return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+        }
+    }
+    else
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+            ( MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED ||
+              MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED ) */
 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||                     \
     defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) ||                     \
     defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED)
@@ -2938,7 +3018,9 @@
 static int ssl_write_client_key_exchange( mbedtls_ssl_context *ssl )
 {
     int ret;
-    size_t i, n;
+
+    size_t header_len;
+    size_t content_len;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info =
         ssl->transform_negotiate->ciphersuite_info;
 
@@ -2950,16 +3032,16 @@
         /*
          * DHM key exchange -- send G^X mod P
          */
-        n = ssl->handshake->dhm_ctx.len;
+        content_len = ssl->handshake->dhm_ctx.len;
 
-        ssl->out_msg[4] = (unsigned char)( n >> 8 );
-        ssl->out_msg[5] = (unsigned char)( n      );
-        i = 6;
+        ssl->out_msg[4] = (unsigned char)( content_len >> 8 );
+        ssl->out_msg[5] = (unsigned char)( content_len      );
+        header_len = 6;
 
         ret = mbedtls_dhm_make_public( &ssl->handshake->dhm_ctx,
-                                (int) mbedtls_mpi_size( &ssl->handshake->dhm_ctx.P ),
-                               &ssl->out_msg[i], n,
-                                ssl->conf->f_rng, ssl->conf->p_rng );
+                           (int) mbedtls_mpi_size( &ssl->handshake->dhm_ctx.P ),
+                           &ssl->out_msg[header_len], content_len,
+                           ssl->conf->f_rng, ssl->conf->p_rng );
         if( ret != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_dhm_make_public", ret );
@@ -2970,10 +3052,10 @@
         MBEDTLS_SSL_DEBUG_MPI( 3, "DHM: GX", &ssl->handshake->dhm_ctx.GX );
 
         if( ( ret = mbedtls_dhm_calc_secret( &ssl->handshake->dhm_ctx,
-                                      ssl->handshake->premaster,
-                                      MBEDTLS_PREMASTER_SIZE,
-                                     &ssl->handshake->pmslen,
-                                      ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
+                                   ssl->handshake->premaster,
+                                   MBEDTLS_PREMASTER_SIZE,
+                                   &ssl->handshake->pmslen,
+                                   ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_dhm_calc_secret", ret );
             return( ret );
@@ -2983,6 +3065,119 @@
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_RSA_ENABLED */
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
+        ( defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||     \
+          defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) )
+    if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA ||
+        ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA )
+    {
+        psa_status_t status;
+        psa_key_policy_t policy;
+
+        mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+        unsigned char own_pubkey[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
+        size_t own_pubkey_len;
+        unsigned char *own_pubkey_ecpoint;
+        size_t own_pubkey_ecpoint_len;
+
+        psa_crypto_generator_t generator = PSA_CRYPTO_GENERATOR_INIT;
+
+        header_len = 4;
+
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based ECDH computation." ) );
+
+        /*
+         * Generate EC private key for ECDHE exchange.
+         */
+
+        /* Allocate a new key slot for the private key. */
+
+        status = psa_allocate_key( &handshake->ecdh_psa_privkey );
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+        /* The master secret is obtained from the shared ECDH secret by
+         * applying the TLS 1.2 PRF with a specific salt and label. While
+         * the PSA Crypto API encourages combining key agreement schemes
+         * such as ECDH with fixed KDFs such as TLS 1.2 PRF, it does not
+         * yet support the provisioning of salt + label to the KDF.
+         * For the time being, we therefore need to split the computation
+         * of the ECDH secret and the application of the TLS 1.2 PRF. */
+        policy = psa_key_policy_init();
+        psa_key_policy_set_usage( &policy,
+                                  PSA_KEY_USAGE_DERIVE,
+                                  PSA_ALG_ECDH( PSA_ALG_SELECT_RAW ) );
+        status = psa_set_key_policy( handshake->ecdh_psa_privkey, &policy );
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+        /* Generate ECDH private key. */
+        status = psa_generate_key( handshake->ecdh_psa_privkey,
+                          PSA_KEY_TYPE_ECC_KEYPAIR( handshake->ecdh_psa_curve ),
+                          MBEDTLS_PSA_ECC_KEY_BITS_OF_CURVE( handshake->ecdh_psa_curve ),
+                          NULL, 0 );
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+        /* Export the public part of the ECDH private key from PSA
+         * and convert it to ECPoint format used in ClientKeyExchange. */
+        status = psa_export_public_key( handshake->ecdh_psa_privkey,
+                                        own_pubkey, sizeof( own_pubkey ),
+                                        &own_pubkey_len );
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+        if( mbedtls_psa_tls_psa_ec_to_ecpoint( own_pubkey,
+                                               own_pubkey_len,
+                                               &own_pubkey_ecpoint,
+                                               &own_pubkey_ecpoint_len ) != 0 )
+        {
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+        }
+
+        /* Copy ECPoint structure to outgoing message buffer. */
+        ssl->out_msg[header_len] = own_pubkey_ecpoint_len;
+        memcpy( ssl->out_msg + header_len + 1,
+                own_pubkey_ecpoint, own_pubkey_ecpoint_len );
+        content_len = own_pubkey_ecpoint_len + 1;
+
+        /* Compute ECDH shared secret. */
+        status = psa_key_agreement( &generator,
+                                    handshake->ecdh_psa_privkey,
+                                    handshake->ecdh_psa_peerkey,
+                                    handshake->ecdh_psa_peerkey_len,
+                                    PSA_ALG_ECDH( PSA_ALG_SELECT_RAW ) );
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+        /* The ECDH secret is the premaster secret used for key derivation. */
+
+        ssl->handshake->pmslen =
+            MBEDTLS_PSA_ECC_KEY_BYTES_OF_CURVE( handshake->ecdh_psa_curve );
+
+        status = psa_generator_read( &generator,
+                                     ssl->handshake->premaster,
+                                     ssl->handshake->pmslen );
+        if( status != PSA_SUCCESS )
+        {
+            psa_generator_abort( &generator );
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+        }
+
+        status = psa_generator_abort( &generator );
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+        status = psa_destroy_key( handshake->ecdh_psa_privkey );
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+        handshake->ecdh_psa_privkey = 0;
+    }
+    else
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+            ( MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED ||
+              MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED ) */
 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||                     \
     defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) ||                   \
     defined(MBEDTLS_KEY_EXCHANGE_ECDH_RSA_ENABLED) ||                      \
@@ -2995,7 +3190,7 @@
         /*
          * ECDH key exchange -- send client public value
          */
-        i = 4;
+        header_len = 4;
 
 #if defined(MBEDTLS_SSL__ECP_RESTARTABLE)
         if( ssl->handshake->ecrs_enabled )
@@ -3008,8 +3203,8 @@
 #endif
 
         ret = mbedtls_ecdh_make_public( &ssl->handshake->ecdh_ctx,
-                                &n,
-                                &ssl->out_msg[i], 1000,
+                                &content_len,
+                                &ssl->out_msg[header_len], 1000,
                                 ssl->conf->f_rng, ssl->conf->p_rng );
         if( ret != 0 )
         {
@@ -3027,19 +3222,19 @@
 #if defined(MBEDTLS_SSL__ECP_RESTARTABLE)
         if( ssl->handshake->ecrs_enabled )
         {
-            ssl->handshake->ecrs_n = n;
+            ssl->handshake->ecrs_n = content_len;
             ssl->handshake->ecrs_state = ssl_ecrs_cke_ecdh_calc_secret;
         }
 
 ecdh_calc_secret:
         if( ssl->handshake->ecrs_enabled )
-            n = ssl->handshake->ecrs_n;
+            content_len = ssl->handshake->ecrs_n;
 #endif
         if( ( ret = mbedtls_ecdh_calc_secret( &ssl->handshake->ecdh_ctx,
-                                      &ssl->handshake->pmslen,
-                                       ssl->handshake->premaster,
-                                       MBEDTLS_MPI_MAX_SIZE,
-                                       ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
+                                   &ssl->handshake->pmslen,
+                                   ssl->handshake->premaster,
+                                   MBEDTLS_MPI_MAX_SIZE,
+                                   ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_calc_secret", ret );
 #if defined(MBEDTLS_SSL__ECP_RESTARTABLE)
@@ -3071,26 +3266,28 @@
             return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
         }
 
-        i = 4;
-        n = ssl->conf->psk_identity_len;
+        header_len = 4;
+        content_len = ssl->conf->psk_identity_len;
 
-        if( i + 2 + n > MBEDTLS_SSL_OUT_CONTENT_LEN )
+        if( header_len + 2 + content_len > MBEDTLS_SSL_OUT_CONTENT_LEN )
         {
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "psk identity too long or "
                                         "SSL buffer too short" ) );
             return( MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL );
         }
 
-        ssl->out_msg[i++] = (unsigned char)( n >> 8 );
-        ssl->out_msg[i++] = (unsigned char)( n      );
+        ssl->out_msg[header_len++] = (unsigned char)( content_len >> 8 );
+        ssl->out_msg[header_len++] = (unsigned char)( content_len      );
 
-        memcpy( ssl->out_msg + i, ssl->conf->psk_identity, ssl->conf->psk_identity_len );
-        i += ssl->conf->psk_identity_len;
+        memcpy( ssl->out_msg + header_len,
+                ssl->conf->psk_identity,
+                ssl->conf->psk_identity_len );
+        header_len += ssl->conf->psk_identity_len;
 
 #if defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
         if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_PSK )
         {
-            n = 0;
+            content_len = 0;
         }
         else
 #endif
@@ -3103,7 +3300,8 @@
                 return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
-            if( ( ret = ssl_write_encrypted_pms( ssl, i, &n, 2 ) ) != 0 )
+            if( ( ret = ssl_write_encrypted_pms( ssl, header_len,
+                                                 &content_len, 2 ) ) != 0 )
                 return( ret );
         }
         else
@@ -3120,21 +3318,22 @@
             /*
              * ClientDiffieHellmanPublic public (DHM send G^X mod P)
              */
-            n = ssl->handshake->dhm_ctx.len;
+            content_len = ssl->handshake->dhm_ctx.len;
 
-            if( i + 2 + n > MBEDTLS_SSL_OUT_CONTENT_LEN )
+            if( header_len + 2 + content_len >
+                MBEDTLS_SSL_OUT_CONTENT_LEN )
             {
                 MBEDTLS_SSL_DEBUG_MSG( 1, ( "psk identity or DHM size too long"
                                             " or SSL buffer too short" ) );
                 return( MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL );
             }
 
-            ssl->out_msg[i++] = (unsigned char)( n >> 8 );
-            ssl->out_msg[i++] = (unsigned char)( n      );
+            ssl->out_msg[header_len++] = (unsigned char)( content_len >> 8 );
+            ssl->out_msg[header_len++] = (unsigned char)( content_len      );
 
             ret = mbedtls_dhm_make_public( &ssl->handshake->dhm_ctx,
                     (int) mbedtls_mpi_size( &ssl->handshake->dhm_ctx.P ),
-                    &ssl->out_msg[i], n,
+                    &ssl->out_msg[header_len], content_len,
                     ssl->conf->f_rng, ssl->conf->p_rng );
             if( ret != 0 )
             {
@@ -3156,8 +3355,10 @@
             /*
              * ClientECDiffieHellmanPublic public;
              */
-            ret = mbedtls_ecdh_make_public( &ssl->handshake->ecdh_ctx, &n,
-                    &ssl->out_msg[i], MBEDTLS_SSL_OUT_CONTENT_LEN - i,
+            ret = mbedtls_ecdh_make_public( &ssl->handshake->ecdh_ctx,
+                    &content_len,
+                    &ssl->out_msg[header_len],
+                    MBEDTLS_SSL_OUT_CONTENT_LEN - header_len,
                     ssl->conf->f_rng, ssl->conf->p_rng );
             if( ret != 0 )
             {
@@ -3198,8 +3399,9 @@
 #if defined(MBEDTLS_KEY_EXCHANGE_RSA_ENABLED)
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA )
     {
-        i = 4;
-        if( ( ret = ssl_write_encrypted_pms( ssl, i, &n, 0 ) ) != 0 )
+        header_len = 4;
+        if( ( ret = ssl_write_encrypted_pms( ssl, header_len,
+                                             &content_len, 0 ) ) != 0 )
             return( ret );
     }
     else
@@ -3207,10 +3409,12 @@
 #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED)
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE )
     {
-        i = 4;
+        header_len = 4;
 
         ret = mbedtls_ecjpake_write_round_two( &ssl->handshake->ecjpake_ctx,
-                ssl->out_msg + i, MBEDTLS_SSL_OUT_CONTENT_LEN - i, &n,
+                ssl->out_msg + header_len,
+                MBEDTLS_SSL_OUT_CONTENT_LEN - header_len,
+                &content_len,
                 ssl->conf->f_rng, ssl->conf->p_rng );
         if( ret != 0 )
         {
@@ -3235,7 +3439,7 @@
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
     }
 
-    ssl->out_msglen  = i + n;
+    ssl->out_msglen  = header_len + content_len;
     ssl->out_msgtype = MBEDTLS_SSL_MSG_HANDSHAKE;
     ssl->out_msg[0]  = MBEDTLS_SSL_HS_CLIENT_KEY_EXCHANGE;
 
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 6e50c1e..4c23f0e 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -9410,6 +9410,11 @@
     ssl_buffering_free( ssl );
 #endif
 
+#if defined(MBEDTLS_ECDH_C) &&                  \
+    defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_destroy_key( handshake->ecdh_psa_privkey );
+#endif /* MBEDTLS_ECDH_C && MBEDTLS_USE_PSA_CRYPTO */
+
     mbedtls_platform_zeroize( handshake,
                               sizeof( mbedtls_ssl_handshake_params ) );
 }
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index 9b061d3..2688159 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -852,6 +852,7 @@
     msg "build: cmake, full config + MBEDTLS_USE_PSA_CRYPTO, ASan"
     scripts/config.pl full
     scripts/config.pl unset MBEDTLS_MEMORY_BACKTRACE # too slow for tests
+    scripts/config.pl unset MBEDTLS_ECP_RESTARTABLE  # restartable ECC not supported through PSA
     scripts/config.pl set MBEDTLS_PSA_CRYPTO_C
     scripts/config.pl set MBEDTLS_USE_PSA_CRYPTO
     CC=gcc cmake -D USE_CRYPTO_SUBMODULE=1 -D CMAKE_BUILD_TYPE:String=Asan .
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index 1d852ba..ff05f64 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -781,6 +781,30 @@
                 -C "Failed to setup PSA-based cipher context"\
                 -S "Failed to setup PSA-based cipher context"\
                 -s "Protocol is TLSv1.2" \
+                -c "Perform PSA-based ECDH computation."\
+                -c "Perform PSA-based computation of digest of ServerKeyExchange" \
+                -S "error" \
+                -C "error"
+}
+
+run_test_psa_force_curve() {
+    requires_config_enabled MBEDTLS_USE_PSA_CRYPTO
+    run_test    "PSA - ECDH with $1" \
+                "$P_SRV debug_level=4 force_version=tls1_2" \
+                "$P_CLI debug_level=4 force_version=tls1_2 force_ciphersuite=TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256 curves=$1" \
+                0 \
+                -c "Successfully setup PSA-based decryption cipher context" \
+                -c "Successfully setup PSA-based encryption cipher context" \
+                -c "PSA calc verify" \
+                -c "calc PSA finished" \
+                -s "Successfully setup PSA-based decryption cipher context" \
+                -s "Successfully setup PSA-based encryption cipher context" \
+                -s "PSA calc verify" \
+                -s "calc PSA finished" \
+                -C "Failed to setup PSA-based cipher context"\
+                -S "Failed to setup PSA-based cipher context"\
+                -s "Protocol is TLSv1.2" \
+                -c "Perform PSA-based ECDH computation."\
                 -c "Perform PSA-based computation of digest of ServerKeyExchange" \
                 -S "error" \
                 -C "error"
@@ -944,6 +968,29 @@
 run_test_psa TLS-ECDHE-ECDSA-WITH-AES-128-CBC-SHA256
 run_test_psa TLS-ECDHE-ECDSA-WITH-AES-256-CBC-SHA384
 
+requires_config_enabled MBEDTLS_ECP_DP_SECP521R1_ENABLED
+run_test_psa_force_curve "secp521r1"
+requires_config_enabled MBEDTLS_ECP_DP_BP512R1_ENABLED
+run_test_psa_force_curve "brainpoolP512r1"
+requires_config_enabled MBEDTLS_ECP_DP_SECP384R1_ENABLED
+run_test_psa_force_curve "secp384r1"
+requires_config_enabled MBEDTLS_ECP_DP_BP384R1_ENABLED
+run_test_psa_force_curve "brainpoolP384r1"
+requires_config_enabled MBEDTLS_ECP_DP_SECP256R1_ENABLED
+run_test_psa_force_curve "secp256r1"
+requires_config_enabled MBEDTLS_ECP_DP_SECP256K1_ENABLED
+run_test_psa_force_curve "secp256k1"
+requires_config_enabled MBEDTLS_ECP_DP_BP256R1_ENABLED
+run_test_psa_force_curve "brainpoolP256r1"
+requires_config_enabled MBEDTLS_ECP_DP_SECP224R1_ENABLED
+run_test_psa_force_curve "secp224r1"
+requires_config_enabled MBEDTLS_ECP_DP_SECP224K1_ENABLED
+run_test_psa_force_curve "secp224k1"
+requires_config_enabled MBEDTLS_ECP_DP_SECP192R1_ENABLED
+run_test_psa_force_curve "secp192r1"
+requires_config_enabled MBEDTLS_ECP_DP_SECP192K1_ENABLED
+run_test_psa_force_curve "secp192k1"
+
 # Test current time in ServerHello
 requires_config_enabled MBEDTLS_HAVE_TIME
 run_test    "ServerHello contains gmt_unix_time" \