ecjpake: use the PSA Crypto hash API when MBEDTLS_MD_C isn't defined

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/ecjpake.c b/library/ecjpake.c
index 10286c2..7447354 100644
--- a/library/ecjpake.c
+++ b/library/ecjpake.c
@@ -30,6 +30,13 @@
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
 
+/* We use MD first if it's available (for compatibility reasons)
+ * and "fall back" to PSA otherwise (which needs psa_crypto_init()). */
+#if !defined(MBEDTLS_MD_C)
+#include "psa/crypto.h"
+#include "mbedtls/psa_util.h"
+#endif /* !MBEDTLS_MD_C */
+
 #include "hash_info.h"
 
 #include <string.h>
@@ -47,6 +54,28 @@
 #define ID_MINE     ( ecjpake_id[ ctx->role ] )
 #define ID_PEER     ( ecjpake_id[ 1 - ctx->role ] )
 
+/**
+ * \brief  Hash \p input with \p md_type: via the MD module when MBEDTLS_MD_C
+ *         is enabled, via PSA otherwise. \p output must fit the full digest.
+ * \return 0 on success, or a nonzero Mbed TLS error code on failure.
+ */
+static int mbedtls_ecjpake_compute_hash( mbedtls_md_type_t md_type,
+                                         const unsigned char *input, size_t ilen,
+                                         unsigned char *output )
+{
+#if defined(MBEDTLS_MD_C)
+    const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type( md_type );
+    return( mbedtls_md( md_info, input, ilen, output ) );
+#else
+    const psa_algorithm_t hash_alg = mbedtls_psa_translate_md( md_type );
+    const size_t digest_size = PSA_HASH_LENGTH( hash_alg );
+    size_t digest_len = 0; /* written by PSA, not needed by callers */
+    psa_status_t status = psa_hash_compute( hash_alg, input, ilen,
+                                            output, digest_size, &digest_len );
+    return( mbedtls_md_error_from_psa( status ) );
+#endif /* !MBEDTLS_MD_C */
+}
+
 /*
  * Initialize context
  */
@@ -106,8 +135,13 @@
 
     ctx->role = role;
 
+#if defined(MBEDTLS_MD_C)
     if( ( mbedtls_md_info_from_type( hash ) ) == NULL )
         return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE );
+#else /* !MBEDTLS_MD_C: reject hashes with no PSA algorithm equivalent */
+    if( mbedtls_psa_translate_md( hash ) == PSA_ALG_NONE )
+        return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE );
+#endif /* !MBEDTLS_MD_C */
 
     ctx->md_type = hash;
 
@@ -222,8 +256,8 @@
     p += id_len;
 
     /* Compute hash */
-    MBEDTLS_MPI_CHK( mbedtls_md( mbedtls_md_info_from_type( md_type ),
-                                 buf, p - buf, hash ) );
+    MBEDTLS_MPI_CHK( mbedtls_ecjpake_compute_hash( md_type, buf, p - buf,
+                                                   hash ) );
 
     /* Turn it into an integer mod n */
     MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( h, hash,
@@ -763,8 +797,8 @@
     /* PMS = SHA-256( K.X ) */
     x_bytes = ( ctx->grp.pbits + 7 ) / 8;
     MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &K.X, kx, x_bytes ) );
-    MBEDTLS_MPI_CHK( mbedtls_md( mbedtls_md_info_from_type( ctx->md_type ),
-                                 kx, x_bytes, buf ) );
+    MBEDTLS_MPI_CHK( mbedtls_ecjpake_compute_hash( ctx->md_type, kx,
+                                                   x_bytes, buf ) );
 
 cleanup:
     mbedtls_ecp_point_free( &K );