Use PSA hashing for LMS and LMOTS

Signed-off-by: Raef Coles <raef.coles@arm.com>
diff --git a/library/lms.c b/library/lms.c
index d8969ba..4b4f151 100644
--- a/library/lms.c
+++ b/library/lms.c
@@ -38,8 +38,9 @@
 
 #include "lmots.h"
 
+#include "psa/crypto.h"
+
 #include "mbedtls/lms.h"
-#include "mbedtls/md.h"
 #include "mbedtls/error.h"
 #include "mbedtls/platform_util.h"
 
@@ -88,59 +89,61 @@
                                     unsigned int r_node_idx,
                                     unsigned char out[32] )
 {
-    mbedtls_md_context_t hash_ctx;
+    psa_hash_operation_t op;
+    psa_status_t status;
+    size_t output_hash_len;
     unsigned char D_LEAF_bytes[D_CONST_LEN];
     unsigned char r_node_idx_bytes[4];
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    mbedtls_md_init( &hash_ctx );
-    ret = mbedtls_md_setup( &hash_ctx, mbedtls_md_info_from_type( MBEDTLS_MD_SHA256 ), 0 );
-    if( ret )
-    {
-        goto out;
-    }
-    ret = mbedtls_md_starts( &hash_ctx );
-    if( ret )
+    op = psa_hash_operation_init( );
+    status = psa_hash_setup( &op, PSA_ALG_SHA_256 );
+    ret = mbedtls_lms_error_from_psa( status );
+    if( ret != 0 )
     {
         goto out;
     }
 
-    ret = mbedtls_md_update( &hash_ctx,
-                             ctx->MBEDTLS_PRIVATE(I_key_identifier),
-                             MBEDTLS_LMOTS_I_KEY_ID_LEN );
+    status = psa_hash_update( &op, ctx->MBEDTLS_PRIVATE(I_key_identifier),
+                              MBEDTLS_LMOTS_I_KEY_ID_LEN );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
     val_to_network_bytes( r_node_idx, 4, r_node_idx_bytes );
-    ret = mbedtls_md_update( &hash_ctx, r_node_idx_bytes, 4 );
+    status = psa_hash_update( &op, r_node_idx_bytes, 4 );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
     val_to_network_bytes( D_LEAF_CONSTANT, D_CONST_LEN, D_LEAF_bytes );
-    ret = mbedtls_md_update( &hash_ctx, D_LEAF_bytes, D_CONST_LEN );
+    status = psa_hash_update( &op, D_LEAF_bytes, D_CONST_LEN );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
-    ret = mbedtls_md_update( &hash_ctx, pub_key, MBEDTLS_LMOTS_N_HASH_LEN );
+    status = psa_hash_update( &op, pub_key, MBEDTLS_LMOTS_N_HASH_LEN );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
-    ret = mbedtls_md_finish( &hash_ctx, out );
+    status = psa_hash_finish( &op, out, 32, &output_hash_len );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
 out:
-    mbedtls_md_free( &hash_ctx );
+    psa_hash_abort( &op );
 
     return( ret );
 }
@@ -151,64 +154,68 @@
                                     unsigned int r_node_idx,
                                     unsigned char out[32] )
 {
-    mbedtls_md_context_t hash_ctx;
+    psa_hash_operation_t op;
+    psa_status_t status;
+    size_t output_hash_len;
     unsigned char D_INTR_bytes[D_CONST_LEN];
     unsigned char r_node_idx_bytes[4];
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    mbedtls_md_init( &hash_ctx );
-    ret = mbedtls_md_setup( &hash_ctx, mbedtls_md_info_from_type( MBEDTLS_MD_SHA256 ), 0 );
-    if( ret )
-    {
-        goto out;
-    }
-    ret = mbedtls_md_starts( &hash_ctx );
-    if( ret )
+    op = psa_hash_operation_init( );
+    status = psa_hash_setup( &op, PSA_ALG_SHA_256 );
+    ret = mbedtls_lms_error_from_psa( status );
+    if( ret != 0 )
     {
         goto out;
     }
 
-    ret = mbedtls_md_update( &hash_ctx, ctx->MBEDTLS_PRIVATE(I_key_identifier),
-                             MBEDTLS_LMOTS_I_KEY_ID_LEN );
+    status = psa_hash_update( &op, ctx->MBEDTLS_PRIVATE(I_key_identifier),
+                              MBEDTLS_LMOTS_I_KEY_ID_LEN );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
     val_to_network_bytes( r_node_idx, 4, r_node_idx_bytes );
-    ret = mbedtls_md_update( &hash_ctx, r_node_idx_bytes, 4 );
+    status = psa_hash_update( &op, r_node_idx_bytes, 4 );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
     val_to_network_bytes( D_INTR_CONSTANT, D_CONST_LEN, D_INTR_bytes );
-    ret = mbedtls_md_update( &hash_ctx, D_INTR_bytes, D_CONST_LEN );
+    status = psa_hash_update( &op, D_INTR_bytes, D_CONST_LEN );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
-    ret = mbedtls_md_update( &hash_ctx, left_node, MBEDTLS_LMOTS_N_HASH_LEN );
+    status = psa_hash_update( &op, left_node, MBEDTLS_LMOTS_N_HASH_LEN );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
-    ret = mbedtls_md_update( &hash_ctx, rght_node, MBEDTLS_LMOTS_N_HASH_LEN );
+    status = psa_hash_update( &op, rght_node, MBEDTLS_LMOTS_N_HASH_LEN );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
-    ret = mbedtls_md_finish( &hash_ctx, out );
+    status = psa_hash_finish( &op, out, 32, &output_hash_len );
+    ret = mbedtls_lms_error_from_psa( status );
     if( ret )
     {
         goto out;
     }
 
 out:
-    mbedtls_md_free( &hash_ctx );
+    psa_hash_abort( &op );
 
     return ret;
 }