Unify similar functions

Use common funtion for psa_sign_hash and psa_sign_message and one for
psa_verify_hash and psa_verify_message to unify them.

Signed-off-by: gabor-mezei-arm <gabor.mezei@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index c21e03b..4cb6ff3 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -2842,26 +2842,51 @@
 /* Asymmetric cryptography */
 /****************************************************************/
 
-psa_status_t psa_sign_message( mbedtls_svc_key_id_t key,
-                               psa_algorithm_t alg,
-                               const uint8_t * input,
-                               size_t input_length,
-                               uint8_t * signature,
-                               size_t signature_size,
-                               size_t * signature_length )
+typedef enum
+{
+    PSA_SIGN_INVALID = 0,
+    PSA_SIGN_HASH = 1,
+    PSA_SIGN_MESSAGE
+} psa_sign_operation_t;
+
+typedef enum
+{
+    PSA_VERIFY_INVALID = 0,
+    PSA_VERIFY_HASH = 1,
+    PSA_VERIFY_MESSAGE
+} psa_verify_operation_t;
+
+static psa_status_t psa_sign_internal( mbedtls_svc_key_id_t key,
+                                       psa_sign_operation_t operation,
+                                       psa_algorithm_t alg,
+                                       const uint8_t * input,
+                                       size_t input_length,
+                                       uint8_t * signature,
+                                       size_t signature_size,
+                                       size_t * signature_length )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
-    size_t hash_length;
-    uint8_t hash[PSA_HASH_MAX_SIZE];
 
     *signature_length = 0;
 
-    if( ! PSA_ALG_IS_SIGN_MESSAGE( alg ) )
-        return( PSA_ERROR_INVALID_ARGUMENT );
+    if( operation == PSA_SIGN_MESSAGE )
+    {
+        if( ! PSA_ALG_IS_SIGN_MESSAGE( alg ) )
+            return( PSA_ERROR_INVALID_ARGUMENT );
 
-    if ( ! PSA_ALG_IS_HASH( PSA_ALG_SIGN_GET_HASH( alg ) ) )
+        if ( PSA_ALG_IS_HASH_AND_SIGN( alg ) )
+        {
+            if( ! PSA_ALG_IS_HASH( PSA_ALG_SIGN_GET_HASH( alg ) ) )
+                return( PSA_ERROR_INVALID_ARGUMENT );
+        }
+        /* Curently only hash-then-sign algorithms are supported. */
+        else
+            return( PSA_ERROR_INVALID_ARGUMENT );
+    }
+
+    else if( operation == PSA_SIGN_INVALID )
         return( PSA_ERROR_INVALID_ARGUMENT );
 
     /* Immediately reject a zero-length signature buffer. This guarantees
@@ -2871,9 +2896,12 @@
     if( signature_size == 0 )
         return( PSA_ERROR_BUFFER_TOO_SMALL );
 
-    status = psa_get_and_lock_key_slot_with_policy( key, &slot,
-                                                    PSA_KEY_USAGE_SIGN_MESSAGE,
-                                                    alg );
+    status = psa_get_and_lock_key_slot_with_policy(
+                key, &slot,
+                operation == PSA_SIGN_HASH ? PSA_KEY_USAGE_SIGN_HASH :
+                                             PSA_KEY_USAGE_SIGN_MESSAGE,
+                alg );
+
     if( status != PSA_SUCCESS )
         goto exit;
 
@@ -2887,23 +2915,33 @@
       .core = slot->attr
     };
 
-    status = psa_driver_wrapper_hash_compute( PSA_ALG_SIGN_GET_HASH( alg ),
-                                              input, input_length,
-                                              hash, sizeof( hash ),
-                                              &hash_length );
-
-    if( status != PSA_SUCCESS )
+    if( operation == PSA_SIGN_MESSAGE )
     {
-        memset( hash, 0, sizeof( hash ) );
-        goto exit;
+        size_t hash_length;
+        uint8_t hash[PSA_HASH_MAX_SIZE];
+
+        status = psa_driver_wrapper_hash_compute( PSA_ALG_SIGN_GET_HASH( alg ),
+                                                  input, input_length,
+                                                  hash, sizeof( hash ),
+                                                  &hash_length );
+
+        if( status != PSA_SUCCESS )
+            goto exit;
+
+        status = psa_driver_wrapper_sign_hash(
+            &attributes, slot->key.data, slot->key.bytes,
+            alg, hash, hash_length,
+            signature, signature_size, signature_length );
+    }
+    else if( operation == PSA_SIGN_HASH )
+    {
+
+        status = psa_driver_wrapper_sign_hash(
+            &attributes, slot->key.data, slot->key.bytes,
+            alg, input, input_length,
+            signature, signature_size, signature_length );
     }
 
-    status = psa_driver_wrapper_sign_hash(
-        &attributes, slot->key.data, slot->key.bytes,
-        alg, hash, hash_length,
-        signature, signature_size, signature_length );
-
-    memset( hash, 0, hash_length );
 
 exit:
     /* Fill the unused part of the output buffer (the whole buffer on error,
@@ -2923,28 +2961,42 @@
     return( ( status == PSA_SUCCESS ) ? unlock_status : status );
 }
 
-psa_status_t psa_verify_message( mbedtls_svc_key_id_t key,
-                                 psa_algorithm_t alg,
-                                 const uint8_t * input,
-                                 size_t input_length,
-                                 const uint8_t * signature,
-                                 size_t signature_length )
+static psa_status_t psa_verify_internal( mbedtls_svc_key_id_t key,
+                                         psa_verify_operation_t operation,
+                                         psa_algorithm_t alg,
+                                         const uint8_t * input,
+                                         size_t input_length,
+                                         const uint8_t * signature,
+                                         size_t signature_length )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
-    size_t hash_length;
-    uint8_t hash[PSA_HASH_MAX_SIZE];
 
-    if( ! PSA_ALG_IS_SIGN_MESSAGE( alg ) )
+    if( operation == PSA_VERIFY_MESSAGE )
+    {
+        if( ! PSA_ALG_IS_SIGN_MESSAGE( alg ) )
+            return( PSA_ERROR_INVALID_ARGUMENT );
+
+        if ( PSA_ALG_IS_HASH_AND_SIGN( alg ) )
+        {
+            if( ! PSA_ALG_IS_HASH( PSA_ALG_SIGN_GET_HASH( alg ) ) )
+                return( PSA_ERROR_INVALID_ARGUMENT );
+        }
+        /* Curently only hash-then-sign algorithms are supported. */
+        else
+            return( PSA_ERROR_INVALID_ARGUMENT );
+    }
+
+    else if( operation == PSA_VERIFY_INVALID )
         return( PSA_ERROR_INVALID_ARGUMENT );
 
-    if ( ! PSA_ALG_IS_HASH( PSA_ALG_SIGN_GET_HASH( alg ) ) )
-        return( PSA_ERROR_INVALID_ARGUMENT );
+    status = psa_get_and_lock_key_slot_with_policy(
+                key, &slot,
+                operation == PSA_VERIFY_HASH ? PSA_KEY_USAGE_VERIFY_HASH :
+                                               PSA_KEY_USAGE_VERIFY_MESSAGE,
+                alg );
 
-    status = psa_get_and_lock_key_slot_with_policy( key, &slot,
-                                                    PSA_KEY_USAGE_VERIFY_MESSAGE,
-                                                    alg );
     if( status != PSA_SUCCESS )
         return( status );
 
@@ -2952,28 +3004,62 @@
       .core = slot->attr
     };
 
-    status = psa_driver_wrapper_hash_compute( PSA_ALG_SIGN_GET_HASH( alg ),
-                                              input, input_length,
-                                              hash, sizeof( hash ),
-                                              &hash_length );
-
-    if( status != PSA_SUCCESS )
+    if( operation == PSA_VERIFY_MESSAGE )
     {
-        memset( hash, 0, sizeof( hash ) );
-        goto exit;
+        size_t hash_length;
+        uint8_t hash[PSA_HASH_MAX_SIZE];
+
+        status = psa_driver_wrapper_hash_compute( PSA_ALG_SIGN_GET_HASH( alg ),
+                                                  input, input_length,
+                                                  hash, sizeof( hash ),
+                                                  &hash_length );
+
+        if( status != PSA_SUCCESS )
+            goto exit;
+
+        status = psa_driver_wrapper_verify_hash(
+            &attributes, slot->key.data, slot->key.bytes,
+            alg, hash, hash_length,
+            signature, signature_length );
     }
-
-    status = psa_driver_wrapper_verify_hash(
-        &attributes, slot->key.data, slot->key.bytes,
-        alg, hash, hash_length,
-        signature, signature_length );
-
-    memset( hash, 0, hash_length );
+    else if( operation == PSA_VERIFY_HASH )
+    {
+        status = psa_driver_wrapper_verify_hash(
+            &attributes, slot->key.data, slot->key.bytes,
+            alg, input, input_length,
+            signature, signature_length );
+    }
 
 exit:
     unlock_status = psa_unlock_key_slot( slot );
 
     return( ( status == PSA_SUCCESS ) ? unlock_status : status );
+
+}
+
+psa_status_t psa_sign_message( mbedtls_svc_key_id_t key,
+                               psa_algorithm_t alg,
+                               const uint8_t * input,
+                               size_t input_length,
+                               uint8_t * signature,
+                               size_t signature_size,
+                               size_t * signature_length )
+{
+    return psa_sign_internal(
+        key, PSA_SIGN_MESSAGE, alg, input, input_length,
+        signature, signature_size, signature_length );
+}
+
+psa_status_t psa_verify_message( mbedtls_svc_key_id_t key,
+                                 psa_algorithm_t alg,
+                                 const uint8_t * input,
+                                 size_t input_length,
+                                 const uint8_t * signature,
+                                 size_t signature_length )
+{
+    return psa_verify_internal(
+        key, PSA_VERIFY_MESSAGE, alg, input, input_length,
+        signature, signature_length );
 }
 
 psa_status_t psa_sign_hash_internal(
@@ -3042,54 +3128,9 @@
                             size_t signature_size,
                             size_t *signature_length )
 {
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_key_slot_t *slot;
-
-    *signature_length = signature_size;
-    /* Immediately reject a zero-length signature buffer. This guarantees
-     * that signature must be a valid pointer. (On the other hand, the hash
-     * buffer can in principle be empty since it doesn't actually have
-     * to be a hash.) */
-    if( signature_size == 0 )
-        return( PSA_ERROR_BUFFER_TOO_SMALL );
-
-    status = psa_get_and_lock_key_slot_with_policy( key, &slot,
-                                                    PSA_KEY_USAGE_SIGN_HASH,
-                                                    alg );
-    if( status != PSA_SUCCESS )
-        goto exit;
-    if( ! PSA_KEY_TYPE_IS_KEY_PAIR( slot->attr.type ) )
-    {
-        status = PSA_ERROR_INVALID_ARGUMENT;
-        goto exit;
-    }
-
-    psa_key_attributes_t attributes = {
-      .core = slot->attr
-    };
-
-    status = psa_driver_wrapper_sign_hash(
-        &attributes, slot->key.data, slot->key.bytes,
-        alg, hash, hash_length,
+    return psa_sign_internal(
+        key, PSA_SIGN_HASH, alg, hash, hash_length,
         signature, signature_size, signature_length );
-
-exit:
-    /* Fill the unused part of the output buffer (the whole buffer on error,
-     * the trailing part on success) with something that isn't a valid mac
-     * (barring an attack on the mac and deliberately-crafted input),
-     * in case the caller doesn't check the return status properly. */
-    if( status == PSA_SUCCESS )
-        memset( signature + *signature_length, '!',
-                signature_size - *signature_length );
-    else
-        memset( signature, '!', signature_size );
-    /* If signature_size is 0 then we have nothing to do. We must not call
-     * memset because signature may be NULL in this case. */
-
-    unlock_status = psa_unlock_key_slot( slot );
-
-    return( ( status == PSA_SUCCESS ) ? unlock_status : status );
 }
 
 psa_status_t psa_verify_hash_internal(
@@ -3156,28 +3197,9 @@
                               const uint8_t *signature,
                               size_t signature_length )
 {
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_key_slot_t *slot;
-
-    status = psa_get_and_lock_key_slot_with_policy( key, &slot,
-                                                    PSA_KEY_USAGE_VERIFY_HASH,
-                                                    alg );
-    if( status != PSA_SUCCESS )
-        return( status );
-
-    psa_key_attributes_t attributes = {
-      .core = slot->attr
-    };
-
-    status = psa_driver_wrapper_verify_hash(
-        &attributes, slot->key.data, slot->key.bytes,
-        alg, hash, hash_length,
+    return psa_verify_internal(
+        key, PSA_VERIFY_HASH, alg, hash, hash_length,
         signature, signature_length );
-
-    unlock_status = psa_unlock_key_slot( slot );
-
-    return( ( status == PSA_SUCCESS ) ? unlock_status : status );
 }
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)