Make the mbedtls_psa_hkdf_extract function more PSA compatible

Change the return value to `psa_status_t`.
Add `prk_size` and `prk_len` parameters.

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ssl_tls13_invasive.h b/library/ssl_tls13_invasive.h
index e3b1dc7..9f30c4a 100644
--- a/library/ssl_tls13_invasive.h
+++ b/library/ssl_tls13_invasive.h
@@ -28,10 +28,11 @@
 
 #if defined(MBEDTLS_PSA_CRYPTO_C)
 
-int mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
-                              const unsigned char *salt, size_t salt_len,
-                              const unsigned char *ikm, size_t ikm_len,
-                              unsigned char *prk );
+psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
+                                       const unsigned char *salt, size_t salt_len,
+                                       const unsigned char *ikm, size_t ikm_len,
+                                       unsigned char *prk, size_t prk_size,
+                                       size_t *prk_len );
 
 /**
  *  \brief  Expand the supplied \p prk into several additional pseudorandom
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index e63f83a..ad794e6 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -139,16 +139,16 @@
 #if defined( MBEDTLS_TEST_HOOKS )
 
 MBEDTLS_STATIC_TESTABLE
-int mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
-                              const unsigned char *salt, size_t salt_len,
-                              const unsigned char *ikm, size_t ikm_len,
-                              unsigned char *prk )
+psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
+                                       const unsigned char *salt, size_t salt_len,
+                                       const unsigned char *ikm, size_t ikm_len,
+                                       unsigned char *prk, size_t prk_size,
+                                       size_t *prk_len )
 {
     unsigned char null_salt[PSA_MAC_MAX_SIZE] = { '\0' };
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
-    size_t prk_len;
-    int ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+    psa_status_t ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
 
     if( salt == NULL || salt_len == 0 )
     {
@@ -181,7 +181,7 @@
         goto cleanup;
     }
 
-    ret = psa_mac_compute( key, alg, ikm, ikm_len, prk, PSA_HASH_LENGTH( alg ), &prk_len );
+    ret = psa_mac_compute( key, alg, ikm, ikm_len, prk, prk_size, prk_len );
 
 cleanup:
     psa_destroy_key( key );
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 0122d46..c8b70a3 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -3814,10 +3814,10 @@
     unsigned char *salt = NULL;
     unsigned char *prk = NULL;
     unsigned char *output_prk = NULL;
-    size_t ikm_len, salt_len, prk_len, output_prk_len;
+    size_t ikm_len, salt_len, prk_len, output_prk_size, output_prk_len;
 
-    output_prk_len = PSA_HASH_LENGTH( alg );
-    output_prk = mbedtls_calloc( 1, output_prk_len );
+    output_prk_size = PSA_HASH_LENGTH( alg );
+    output_prk = mbedtls_calloc( 1, output_prk_size );
 
     ikm = mbedtls_test_unhexify_alloc( hex_ikm_string, &ikm_len );
     salt = mbedtls_test_unhexify_alloc( hex_salt_string, &salt_len );
@@ -3825,7 +3825,9 @@
 
     PSA_ASSERT( psa_crypto_init() );
     PSA_ASSERT( mbedtls_psa_hkdf_extract( alg, salt, salt_len,
-                                          ikm, ikm_len, output_prk ) );
+                                          ikm, ikm_len,
+                                          output_prk, output_prk_size,
+                                          &output_prk_len ) );
 
     ASSERT_COMPARE( output_prk, output_prk_len, prk, prk_len );
 
@@ -3846,16 +3848,19 @@
     unsigned char *salt = NULL;
     unsigned char *ikm = NULL;
     unsigned char *prk = NULL;
-    size_t salt_len, ikm_len;
+    size_t salt_len, ikm_len, prk_len;
 
     prk = mbedtls_calloc( PSA_MAC_MAX_SIZE, 1 );
     salt_len = hash_len;
     ikm_len = 0;
+    prk_len = 0;
 
     PSA_ASSERT( psa_crypto_init() );
     output_ret = mbedtls_psa_hkdf_extract( 0, salt, salt_len,
-                                           ikm, ikm_len, prk );
+                                           ikm, ikm_len,
+                                           prk, PSA_MAC_MAX_SIZE, &prk_len );
     TEST_ASSERT( output_ret == ret );
+    TEST_ASSERT( prk_len == 0 );
 
 exit:
     mbedtls_free(prk);