Split hashing operations out into an mbedTLS hash driver

Signed-off-by: Steven Cooreman <steven.cooreman@silabs.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 6225272..84cf32d 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -33,6 +33,7 @@
 #include "psa_crypto_invasive.h"
 #include "psa_crypto_driver_wrappers.h"
 #include "psa_crypto_ecp.h"
+#include "psa_crypto_hash.h"
 #include "psa_crypto_rsa.h"
 #include "psa_crypto_ecp.h"
 #if defined(MBEDTLS_PSA_CRYPTO_SE_C)
@@ -2196,219 +2197,71 @@

 psa_status_t psa_hash_abort( psa_hash_operation_t *operation )
 {
-    switch( operation->alg )
+    if( operation != NULL )
     {
-        case 0:
-            /* The object has (apparently) been initialized but it is not
-             * in use. It's ok to call abort on such an object, and there's
-             * nothing to do. */
-            break;
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD2)
-        case PSA_ALG_MD2:
-            mbedtls_md2_free( &operation->ctx.md2 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD4)
-        case PSA_ALG_MD4:
-            mbedtls_md4_free( &operation->ctx.md4 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD5)
-        case PSA_ALG_MD5:
-            mbedtls_md5_free( &operation->ctx.md5 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RIPEMD160)
-        case PSA_ALG_RIPEMD160:
-            mbedtls_ripemd160_free( &operation->ctx.ripemd160 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_1)
-        case PSA_ALG_SHA_1:
-            mbedtls_sha1_free( &operation->ctx.sha1 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_224)
-        case PSA_ALG_SHA_224:
-            mbedtls_sha256_free( &operation->ctx.sha256 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_256)
-        case PSA_ALG_SHA_256:
-            mbedtls_sha256_free( &operation->ctx.sha256 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_384)
-        case PSA_ALG_SHA_384:
-            mbedtls_sha512_free( &operation->ctx.sha512 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_512)
-        case PSA_ALG_SHA_512:
-            mbedtls_sha512_free( &operation->ctx.sha512 );
-            break;
-#endif
-        default:
-            return( PSA_ERROR_BAD_STATE );
+        if( operation->ctx.ctx != NULL )
+        {
+            psa_status_t status = mbedtls_psa_hash_abort( operation->ctx.ctx );
+            mbedtls_free( operation->ctx.ctx );
+            operation->ctx.ctx = NULL;
+            return( status );
+        }
+        else
+        {
+            /* Inactive operation (never set up, or already aborted):
+             * nothing to release, so repeated aborts succeed. */
+            return( PSA_SUCCESS );
+        }
     }
-    operation->alg = 0;
-    return( PSA_SUCCESS );
+    else
+        return( PSA_ERROR_INVALID_ARGUMENT );
 }
 
 psa_status_t psa_hash_setup( psa_hash_operation_t *operation,
                              psa_algorithm_t alg )
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    if( operation == NULL || !PSA_ALG_IS_HASH( alg ) )
+        return( PSA_ERROR_INVALID_ARGUMENT );
 
     /* A context must be freshly initialized before it can be set up. */
-    if( operation->alg != 0 )
-    {
+    if( operation->ctx.ctx != NULL )
         return( PSA_ERROR_BAD_STATE );
-    }
 
-    switch( alg )
-    {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD2)
-        case PSA_ALG_MD2:
-            mbedtls_md2_init( &operation->ctx.md2 );
-            ret = mbedtls_md2_starts_ret( &operation->ctx.md2 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD4)
-        case PSA_ALG_MD4:
-            mbedtls_md4_init( &operation->ctx.md4 );
-            ret = mbedtls_md4_starts_ret( &operation->ctx.md4 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD5)
-        case PSA_ALG_MD5:
-            mbedtls_md5_init( &operation->ctx.md5 );
-            ret = mbedtls_md5_starts_ret( &operation->ctx.md5 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RIPEMD160)
-        case PSA_ALG_RIPEMD160:
-            mbedtls_ripemd160_init( &operation->ctx.ripemd160 );
-            ret = mbedtls_ripemd160_starts_ret( &operation->ctx.ripemd160 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_1)
-        case PSA_ALG_SHA_1:
-            mbedtls_sha1_init( &operation->ctx.sha1 );
-            ret = mbedtls_sha1_starts_ret( &operation->ctx.sha1 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_224)
-        case PSA_ALG_SHA_224:
-            mbedtls_sha256_init( &operation->ctx.sha256 );
-            ret = mbedtls_sha256_starts_ret( &operation->ctx.sha256, 1 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_256)
-        case PSA_ALG_SHA_256:
-            mbedtls_sha256_init( &operation->ctx.sha256 );
-            ret = mbedtls_sha256_starts_ret( &operation->ctx.sha256, 0 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_384)
-        case PSA_ALG_SHA_384:
-            mbedtls_sha512_init( &operation->ctx.sha512 );
-            ret = mbedtls_sha512_starts_ret( &operation->ctx.sha512, 1 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_512)
-        case PSA_ALG_SHA_512:
-            mbedtls_sha512_init( &operation->ctx.sha512 );
-            ret = mbedtls_sha512_starts_ret( &operation->ctx.sha512, 0 );
-            break;
-#endif
-        default:
-            return( PSA_ALG_IS_HASH( alg ) ?
-                    PSA_ERROR_NOT_SUPPORTED :
-                    PSA_ERROR_INVALID_ARGUMENT );
-    }
-    if( ret == 0 )
-        operation->alg = alg;
-    else
+    operation->ctx.ctx =
+        mbedtls_calloc( 1, sizeof( mbedtls_psa_hash_operation_t ) );
+    if( operation->ctx.ctx == NULL )
+        return( PSA_ERROR_INSUFFICIENT_MEMORY );
+
+    status = mbedtls_psa_hash_setup( operation->ctx.ctx, alg );
+    /* On failure, free the new context so the operation is inactive again. */
+    if( status != PSA_SUCCESS )
         psa_hash_abort( operation );
-    return( mbedtls_to_psa_error( ret ) );
+    return( status );
 }
 
 psa_status_t psa_hash_update( psa_hash_operation_t *operation,
                               const uint8_t *input,
                               size_t input_length )
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    if( operation == NULL )
+        return( PSA_ERROR_INVALID_ARGUMENT );
+    if( operation->ctx.ctx == NULL )
+        return( PSA_ERROR_BAD_STATE );
 
     /* Don't require hash implementations to behave correctly on a
      * zero-length input, which may have an invalid pointer. */
     if( input_length == 0 )
         return( PSA_SUCCESS );
 
-    switch( operation->alg )
-    {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD2)
-        case PSA_ALG_MD2:
-            ret = mbedtls_md2_update_ret( &operation->ctx.md2,
-                                          input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD4)
-        case PSA_ALG_MD4:
-            ret = mbedtls_md4_update_ret( &operation->ctx.md4,
-                                          input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD5)
-        case PSA_ALG_MD5:
-            ret = mbedtls_md5_update_ret( &operation->ctx.md5,
-                                          input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RIPEMD160)
-        case PSA_ALG_RIPEMD160:
-            ret = mbedtls_ripemd160_update_ret( &operation->ctx.ripemd160,
-                                                input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_1)
-        case PSA_ALG_SHA_1:
-            ret = mbedtls_sha1_update_ret( &operation->ctx.sha1,
-                                           input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_224)
-        case PSA_ALG_SHA_224:
-            ret = mbedtls_sha256_update_ret( &operation->ctx.sha256,
-                                             input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_256)
-        case PSA_ALG_SHA_256:
-            ret = mbedtls_sha256_update_ret( &operation->ctx.sha256,
-                                             input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_384)
-        case PSA_ALG_SHA_384:
-            ret = mbedtls_sha512_update_ret( &operation->ctx.sha512,
-                                             input, input_length );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_512)
-        case PSA_ALG_SHA_512:
-            ret = mbedtls_sha512_update_ret( &operation->ctx.sha512,
-                                             input, input_length );
-            break;
-#endif
-        default:
-            (void)input;
-            return( PSA_ERROR_BAD_STATE );
-    }
-
-    if( ret != 0 )
+    status = mbedtls_psa_hash_update( operation->ctx.ctx,
+                                      input, input_length );
+    if( status != PSA_SUCCESS )
         psa_hash_abort( operation );
-    return( mbedtls_to_psa_error( ret ) );
+    return( status );
 }
 
 psa_status_t psa_hash_finish( psa_hash_operation_t *operation,
@@ -2416,88 +2256,22 @@
                               size_t hash_size,
                               size_t *hash_length )
 {
     psa_status_t status;
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t actual_hash_length = PSA_HASH_LENGTH( operation->alg );
+
+    if( operation == NULL )
+        return( PSA_ERROR_INVALID_ARGUMENT );
+    if( operation->ctx.ctx == NULL )
+        return( PSA_ERROR_BAD_STATE );
 
-    /* Fill the output buffer with something that isn't a valid hash
-     * (barring an attack on the hash and deliberately-crafted input),
-     * in case the caller doesn't check the return status properly. */
-    *hash_length = hash_size;
-    /* If hash_size is 0 then hash may be NULL and then the
-     * call to memset would have undefined behavior. */
-    if( hash_size != 0 )
-        memset( hash, '!', hash_size );
-
-    if( hash_size < actual_hash_length )
-    {
-        status = PSA_ERROR_BUFFER_TOO_SMALL;
-        goto exit;
-    }
-
-    switch( operation->alg )
-    {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD2)
-        case PSA_ALG_MD2:
-            ret = mbedtls_md2_finish_ret( &operation->ctx.md2, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD4)
-        case PSA_ALG_MD4:
-            ret = mbedtls_md4_finish_ret( &operation->ctx.md4, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD5)
-        case PSA_ALG_MD5:
-            ret = mbedtls_md5_finish_ret( &operation->ctx.md5, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RIPEMD160)
-        case PSA_ALG_RIPEMD160:
-            ret = mbedtls_ripemd160_finish_ret( &operation->ctx.ripemd160, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_1)
-        case PSA_ALG_SHA_1:
-            ret = mbedtls_sha1_finish_ret( &operation->ctx.sha1, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_224)
-        case PSA_ALG_SHA_224:
-            ret = mbedtls_sha256_finish_ret( &operation->ctx.sha256, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_256)
-        case PSA_ALG_SHA_256:
-            ret = mbedtls_sha256_finish_ret( &operation->ctx.sha256, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_384)
-        case PSA_ALG_SHA_384:
-            ret = mbedtls_sha512_finish_ret( &operation->ctx.sha512, hash );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_512)
-        case PSA_ALG_SHA_512:
-            ret = mbedtls_sha512_finish_ret( &operation->ctx.sha512, hash );
-            break;
-#endif
-        default:
-            return( PSA_ERROR_BAD_STATE );
-    }
-    status = mbedtls_to_psa_error( ret );
-
-exit:
-    if( status == PSA_SUCCESS )
-    {
-        *hash_length = actual_hash_length;
-        return( psa_hash_abort( operation ) );
-    }
-    else
-    {
-        psa_hash_abort( operation );
-        return( status );
-    }
+    status = mbedtls_psa_hash_finish( operation->ctx.ctx,
+                                      hash, hash_size, hash_length );
+
+    /* The operation ends here whether or not finishing succeeded. On
+     * success, report a failure to release it; otherwise keep the error. */
+    if( status == PSA_SUCCESS )
+        return( psa_hash_abort( operation ) );
+    psa_hash_abort( operation );
+    return( status );
 }
 
 psa_status_t psa_hash_verify( psa_hash_operation_t *operation,
@@ -2523,26 +2290,10 @@
                                uint8_t *hash, size_t hash_size,
                                size_t *hash_length )
 {
-    psa_hash_operation_t operation = PSA_HASH_OPERATION_INIT;
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-
-    *hash_length = hash_size;
-    status = psa_hash_setup( &operation, alg );
-    if( status != PSA_SUCCESS )
-        goto exit;
-    status = psa_hash_update( &operation, input, input_length );
-    if( status != PSA_SUCCESS )
-        goto exit;
-    status = psa_hash_finish( &operation, hash, hash_size, hash_length );
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-exit:
-    if( status == PSA_SUCCESS )
-        status = psa_hash_abort( &operation );
-    else
-        psa_hash_abort( &operation );
-    return( status );
+    /* One-shot computation: pass the whole input straight to the built-in
+     * driver instead of allocating a multipart context. */
+    return( mbedtls_psa_hash_compute( alg, input, input_length,
+                                      hash, hash_size, hash_length ) );
 }
 
 psa_status_t psa_hash_compare( psa_algorithm_t alg,
@@ -2573,73 +2324,26 @@
 psa_status_t psa_hash_clone( const psa_hash_operation_t *source_operation,
                              psa_hash_operation_t *target_operation )
 {
-    if( target_operation->alg != 0 )
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    if( source_operation == NULL || target_operation == NULL )
+        return( PSA_ERROR_INVALID_ARGUMENT );
+    if( source_operation->ctx.ctx == NULL )
+        return( PSA_ERROR_BAD_STATE );
+    if( target_operation->ctx.ctx != NULL )
         return( PSA_ERROR_BAD_STATE );
 
-    switch( source_operation->alg )
-    {
-        case 0:
-            return( PSA_ERROR_BAD_STATE );
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD2)
-        case PSA_ALG_MD2:
-            mbedtls_md2_clone( &target_operation->ctx.md2,
-                               &source_operation->ctx.md2 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD4)
-        case PSA_ALG_MD4:
-            mbedtls_md4_clone( &target_operation->ctx.md4,
-                               &source_operation->ctx.md4 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_MD5)
-        case PSA_ALG_MD5:
-            mbedtls_md5_clone( &target_operation->ctx.md5,
-                               &source_operation->ctx.md5 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RIPEMD160)
-        case PSA_ALG_RIPEMD160:
-            mbedtls_ripemd160_clone( &target_operation->ctx.ripemd160,
-                                     &source_operation->ctx.ripemd160 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_1)
-        case PSA_ALG_SHA_1:
-            mbedtls_sha1_clone( &target_operation->ctx.sha1,
-                                &source_operation->ctx.sha1 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_224)
-        case PSA_ALG_SHA_224:
-            mbedtls_sha256_clone( &target_operation->ctx.sha256,
-                                  &source_operation->ctx.sha256 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_256)
-        case PSA_ALG_SHA_256:
-            mbedtls_sha256_clone( &target_operation->ctx.sha256,
-                                  &source_operation->ctx.sha256 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_384)
-        case PSA_ALG_SHA_384:
-            mbedtls_sha512_clone( &target_operation->ctx.sha512,
-                                  &source_operation->ctx.sha512 );
-            break;
-#endif
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_SHA_512)
-        case PSA_ALG_SHA_512:
-            mbedtls_sha512_clone( &target_operation->ctx.sha512,
-                                  &source_operation->ctx.sha512 );
-            break;
-#endif
-        default:
-            return( PSA_ERROR_NOT_SUPPORTED );
-    }
-
-    target_operation->alg = source_operation->alg;
-    return( PSA_SUCCESS );
+    target_operation->ctx.ctx =
+        mbedtls_calloc( 1, sizeof( mbedtls_psa_hash_operation_t ) );
+    if( target_operation->ctx.ctx == NULL )
+        return( PSA_ERROR_INSUFFICIENT_MEMORY );
+
+    status = mbedtls_psa_hash_clone( source_operation->ctx.ctx,
+                                     target_operation->ctx.ctx );
+    /* On failure, free the new context so the target is inactive again. */
+    if( status != PSA_SUCCESS )
+        psa_hash_abort( target_operation );
+    return( status );
 }
 
 
@@ -2795,7 +2486,7 @@
     if( PSA_ALG_IS_HMAC( operation->alg ) )
     {
         /* We'll set up the hash operation later in psa_hmac_setup_internal. */
-        operation->ctx.hmac.hash_ctx.alg = 0;
+        operation->ctx.hmac.alg = 0; /* NOTE(review): hash_ctx is no longer reset here; confirm it starts zeroed */
         status = PSA_SUCCESS;
     }
     else
@@ -2902,6 +2593,8 @@
     size_t block_size = psa_get_hash_block_size( hash_alg );
     psa_status_t status;
 
+    hmac->alg = hash_alg; /* remembered so later HMAC steps know the hash */
+
     /* Sanity checks on block_size, to guarantee that there won't be a buffer
      * overflow below. This should never trigger if the hash algorithm
      * is implemented correctly. */
@@ -3119,7 +2812,7 @@
                                               size_t mac_size )
 {
     uint8_t tmp[MBEDTLS_MD_MAX_SIZE];
-    psa_algorithm_t hash_alg = hmac->hash_ctx.alg;
+    psa_algorithm_t hash_alg = hmac->alg; /* stored when the HMAC context was set up */
     size_t hash_size = 0;
     size_t block_size = psa_get_hash_block_size( hash_alg );
     psa_status_t status;