Merge remote-tracking branch 'origin/pr/2382' into development-psa
diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h
index fced2cb..f1148af 100644
--- a/include/mbedtls/ssl_internal.h
+++ b/include/mbedtls/ssl_internal.h
@@ -765,6 +765,7 @@
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1) || \
     defined(MBEDTLS_SSL_PROTO_TLS1_2)
+/* The hash buffer must be at least MBEDTLS_MD_MAX_SIZE bytes long. */
 int mbedtls_ssl_get_key_exchange_md_tls1_2( mbedtls_ssl_context *ssl,
                                             unsigned char *hash, size_t *hashlen,
                                             unsigned char *data, size_t data_len,
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 8fe9314..28d1886 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -50,10 +50,19 @@
 
 #include <string.h>
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#include "mbedtls/psa_util.h"
+#include "psa/crypto.h"
+#endif
+
 #if defined(MBEDTLS_X509_CRT_PARSE_C)
 #include "mbedtls/oid.h"
 #endif
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#include "mbedtls/psa_util.h"
+#endif
+
 static void ssl_reset_in_out_pointers( mbedtls_ssl_context *ssl );
 static uint32_t ssl_get_hs_total_len( mbedtls_ssl_context const *ssl );
 
@@ -490,6 +499,76 @@
 #endif /* MBEDTLS_SSL_PROTO_TLS1) || MBEDTLS_SSL_PROTO_TLS1_1 */
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+/*
+ * TLS 1.2 PRF computed through the PSA key derivation API.
+ *
+ * Imports `secret` into a temporary PSA key slot, derives `dlen` bytes with
+ * PSA_ALG_TLS12_PRF (over SHA-384 if md_type is MBEDTLS_MD_SHA384, SHA-256
+ * otherwise) using `random` and `label` as inputs, and writes them to dstbuf.
+ *
+ * Returns 0 on success, or MBEDTLS_ERR_SSL_HW_ACCEL_FAILED if any PSA call
+ * fails. The key slot and the generator are released on every exit path.
+ */
+static int tls_prf_generic( mbedtls_md_type_t md_type,
+                            const unsigned char *secret, size_t slen,
+                            const char *label,
+                            const unsigned char *random, size_t rlen,
+                            unsigned char *dstbuf, size_t dlen )
+{
+    psa_status_t status;
+    psa_status_t cleanup_status;
+    psa_algorithm_t alg;
+    psa_key_policy_t policy;
+    psa_key_handle_t master_slot;
+    psa_crypto_generator_t generator = PSA_CRYPTO_GENERATOR_INIT;
+
+    if( psa_allocate_key( &master_slot ) != PSA_SUCCESS )
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+    if( md_type == MBEDTLS_MD_SHA384 )
+        alg = PSA_ALG_TLS12_PRF(PSA_ALG_SHA_384);
+    else
+        alg = PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256);
+
+    policy = psa_key_policy_init();
+    psa_key_policy_set_usage( &policy, PSA_KEY_USAGE_DERIVE, alg );
+    status = psa_set_key_policy( master_slot, &policy );
+    if( status != PSA_SUCCESS )
+        goto cleanup;
+
+    status = psa_import_key( master_slot, PSA_KEY_TYPE_DERIVE, secret, slen );
+    if( status != PSA_SUCCESS )
+        goto cleanup;
+
+    status = psa_key_derivation( &generator,
+                                 master_slot, alg,
+                                 random, rlen,
+                                 (unsigned char const *) label,
+                                 (size_t) strlen( label ),
+                                 dlen );
+    if( status != PSA_SUCCESS )
+        goto cleanup;
+
+    status = psa_generator_read( &generator, dstbuf, dlen );
+
+cleanup:
+    /* Release both resources on every path so that a failed policy or import
+     * step does not leak the allocated key slot. A cleanup failure is only
+     * reported if the derivation itself succeeded. */
+    cleanup_status = psa_generator_abort( &generator );
+    if( status == PSA_SUCCESS )
+        status = cleanup_status;
+
+    cleanup_status = psa_destroy_key( master_slot );
+    if( status == PSA_SUCCESS )
+        status = cleanup_status;
+
+    return( status == PSA_SUCCESS ? 0 : MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+}
+
+#else /* MBEDTLS_USE_PSA_CRYPTO */
+
 static int tls_prf_generic( mbedtls_md_type_t md_type,
                             const unsigned char *secret, size_t slen,
                             const char *label,
@@ -552,7 +631,7 @@
 
     return( 0 );
 }
-
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 #if defined(MBEDTLS_SHA256_C)
 static int tls_prf_sha256( const unsigned char *secret, size_t slen,
                            const char *label,
@@ -9972,6 +10051,70 @@
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1) || \
     defined(MBEDTLS_SSL_PROTO_TLS1_2)
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+/* PSA-based variant: hashes the 64 bytes of ssl->handshake->randbytes followed
+ * by data into hash (at least MBEDTLS_MD_MAX_SIZE bytes). Returns 0 on success;
+ * on failure sends a fatal internal-error alert and returns an MD error code. */
+int mbedtls_ssl_get_key_exchange_md_tls1_2( mbedtls_ssl_context *ssl,
+                                            unsigned char *hash, size_t *hashlen,
+                                            unsigned char *data, size_t data_len,
+                                            mbedtls_md_type_t md_alg )
+{
+    psa_status_t status;
+    psa_hash_operation_t hash_operation;
+    psa_algorithm_t hash_alg = mbedtls_psa_translate_md( md_alg );
+
+    MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based computation of digest of ServerKeyExchange" ) );
+
+    /* All-bits-zero is a valid initial state for a PSA operation object. */
+    memset( &hash_operation, 0, sizeof( hash_operation ) );
+
+    status = psa_hash_setup( &hash_operation, hash_alg );
+    if( status != PSA_SUCCESS )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_hash_setup", status );
+        goto exit;
+    }
+
+    status = psa_hash_update( &hash_operation, ssl->handshake->randbytes, 64 );
+    if( status == PSA_SUCCESS )
+        status = psa_hash_update( &hash_operation, data, data_len );
+    if( status != PSA_SUCCESS )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_hash_update", status );
+        goto exit;
+    }
+
+    status = psa_hash_finish( &hash_operation, hash, MBEDTLS_MD_MAX_SIZE,
+                              hashlen );
+    if( status != PSA_SUCCESS )
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_hash_finish", status );
+
+exit:
+    if( status == PSA_SUCCESS )
+        return( 0 );
+
+    /* Release the operation on failure; harmless if it is already inactive. */
+    psa_hash_abort( &hash_operation );
+    mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
+                                    MBEDTLS_SSL_ALERT_MSG_INTERNAL_ERROR );
+    switch( status )
+    {
+        case PSA_ERROR_NOT_SUPPORTED:
+            return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE );
+        case PSA_ERROR_BAD_STATE: /* Intentional fallthrough */
+        case PSA_ERROR_BUFFER_TOO_SMALL:
+            return( MBEDTLS_ERR_MD_BAD_INPUT_DATA );
+        case PSA_ERROR_INSUFFICIENT_MEMORY:
+            return( MBEDTLS_ERR_MD_ALLOC_FAILED );
+        default:
+            return( MBEDTLS_ERR_MD_HW_ACCEL_FAILED );
+    }
+}
+
+#else
+
 int mbedtls_ssl_get_key_exchange_md_tls1_2( mbedtls_ssl_context *ssl,
                                             unsigned char *hash, size_t *hashlen,
                                             unsigned char *data, size_t data_len,
@@ -9982,6 +10125,8 @@
     const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type( md_alg );
     *hashlen = mbedtls_md_get_size( md_info );
 
+    MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform mbedtls-based computation of digest of ServerKeyExchange" ) );
+
     mbedtls_md_init( &ctx );
 
     /*
@@ -10026,6 +10171,8 @@
 
     return( ret );
 }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 #endif /* MBEDTLS_SSL_PROTO_TLS1 || MBEDTLS_SSL_PROTO_TLS1_1 || \
           MBEDTLS_SSL_PROTO_TLS1_2 */
 
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index 2ccecc4..30753b7 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -765,6 +765,7 @@
                 -C "Failed to setup PSA-based cipher context"\
                 -S "Failed to setup PSA-based cipher context"\
                 -s "Protocol is TLSv1.2" \
+                -c "Perform PSA-based computation of digest of ServerKeyExchange" \
                 -S "error" \
                 -C "error"
 }