Split multipart AEAD contexts into two parts

Split into the data needed by the internal implementation and the data
needed by the driver implementation, leaving the rest for the PSA layer.

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 4ab0c63..7190aa4 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3214,6 +3214,25 @@
     return( status );
 }
 
+/* Map an AEAD algorithm, including shortened-tag variants, to its base
+ * algorithm. Unrecognized algorithms are returned unchanged: returning a
+ * psa_status_t error code would store a bogus value in operation->alg. */
+static psa_algorithm_t psa_aead_get_base_algorithm( psa_algorithm_t alg )
+{
+    switch( PSA_ALG_AEAD_WITH_SHORTENED_TAG( alg, 0 ) )
+    {
+        case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 0 ):
+            return( PSA_ALG_CCM );
+        case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_GCM, 0 ):
+            return( PSA_ALG_GCM );
+        case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CHACHA20_POLY1305, 0 ):
+            return( PSA_ALG_CHACHA20_POLY1305 );
+        default:
+            /* Keep a valid psa_algorithm_t rather than an error status. */
+            return( alg );
+    }
+}
+
 /* Set the key for a multipart authenticated encryption operation. */
 psa_status_t psa_aead_encrypt_setup( psa_aead_operation_t *operation,
                                      mbedtls_svc_key_id_t key,
@@ -3226,6 +3245,12 @@
     if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
         return( PSA_ERROR_NOT_SUPPORTED );
 
+    if( operation->key_set || operation->nonce_set ||
+        operation->ad_started || operation->body_started )
+    {
+        return( PSA_ERROR_BAD_STATE );
+    }
+
     status = psa_get_and_lock_key_slot_with_policy(
                  key, &slot, PSA_KEY_USAGE_ENCRYPT, alg );
 
@@ -3242,6 +3267,7 @@
                                                     &attributes, slot->key.data,
                                                     slot->key.bytes, alg );
 
+    operation->key_type = psa_get_key_type( &attributes );
 
     unlock_status = psa_unlock_key_slot( slot );
 
@@ -3250,6 +3276,12 @@
         return( unlock_status );
     }
 
+    if( status == PSA_SUCCESS )
+    {
+        operation->alg = psa_aead_get_base_algorithm( alg );
+        operation->key_set = 1;
+    }
+
     return( status );
 }
 
@@ -3265,6 +3297,12 @@
     if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
         return( PSA_ERROR_NOT_SUPPORTED );
 
+    if( operation->key_set || operation->nonce_set ||
+        operation->ad_started || operation->body_started )
+    {
+        return( PSA_ERROR_BAD_STATE );
+    }
+
     status = psa_get_and_lock_key_slot_with_policy(
                  key, &slot, PSA_KEY_USAGE_DECRYPT, alg );
 
@@ -3281,6 +3319,7 @@
                                                     &attributes, slot->key.data,
                                                     slot->key.bytes, alg );
 
+    operation->key_type = psa_get_key_type( &attributes );
 
     unlock_status = psa_unlock_key_slot( slot );
 
@@ -3289,6 +3328,12 @@
         return( unlock_status );
     }
 
+    if( status == PSA_SUCCESS )
+    {
+        operation->alg = psa_aead_get_base_algorithm( alg );
+        operation->key_set = 1;
+    }
+
     return( status );
 }
 
@@ -3341,14 +3386,23 @@
                                  const uint8_t *nonce,
                                  size_t nonce_length )
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
     if( !operation->key_set || operation->nonce_set ||
         operation->ad_started || operation->body_started )
     {
         return( PSA_ERROR_BAD_STATE );
     }
 
-    return( psa_driver_wrapper_aead_set_nonce( operation, nonce,
-                                               nonce_length ) );
+    status = psa_driver_wrapper_aead_set_nonce( operation, nonce,
+                                                nonce_length );
+
+    if( status == PSA_SUCCESS )
+    {
+        operation->nonce_set = 1;
+    }
+
+    return( status );
 }
 
 /* Declare the lengths of the message and additional data for multipart AEAD. */
@@ -3356,26 +3410,44 @@
                                    size_t ad_length,
                                    size_t plaintext_length )
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
     if( !operation->key_set || operation->lengths_set )
     {
         return( PSA_ERROR_BAD_STATE );
     }
 
-    return( psa_driver_wrapper_aead_set_lengths( operation, ad_length,
-                                                 plaintext_length ) );
+    status = psa_driver_wrapper_aead_set_lengths( operation, ad_length,
+                                                  plaintext_length );
+
+    if( status == PSA_SUCCESS )
+    {
+        operation->lengths_set = 1;
+    }
+
+    return status;
 }
  /* Pass additional data to an active multipart AEAD operation. */
 psa_status_t psa_aead_update_ad( psa_aead_operation_t *operation,
                                  const uint8_t *input,
                                  size_t input_length )
 {
-    if( !operation->nonce_set || !operation->key_set )
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    /* A nonce implies a key; AD may not follow payload data. */
+    if( !operation->nonce_set || operation->body_started )
     {
         return( PSA_ERROR_BAD_STATE );
     }
 
-    return( psa_driver_wrapper_aead_update_ad( operation, input,
-                                               input_length ) );
+    status = psa_driver_wrapper_aead_update_ad( operation, input,
+                                                input_length );
+    if( status == PSA_SUCCESS )
+    {
+        operation->ad_started = 1;
+    }
+
+    return( status );
 }
 
 /* Encrypt or decrypt a message fragment in an active multipart AEAD
@@ -3387,6 +3459,7 @@
                               size_t output_size,
                               size_t *output_length )
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
     *output_length = 0;
 
@@ -3395,9 +3468,16 @@
         return( PSA_ERROR_BAD_STATE );
     }
 
-    return( psa_driver_wrapper_aead_update( operation, input, input_length,
-                                            output, output_size,
-                                            output_length ) );
+    status = psa_driver_wrapper_aead_update( operation, input, input_length,
+                                             output, output_size,
+                                             output_length );
+
+    if( status == PSA_SUCCESS )
+    {
+        operation->body_started = 1;
+    }
+
+    return status;
 }
 
 /* Finish encrypting a message in a multipart AEAD operation. */
@@ -3422,6 +3502,7 @@
                                             ciphertext_size,
                                             ciphertext_length,
                                             tag, tag_size, tag_length ) );
+
 }
 
 /* Finish authenticating and decrypting a message in a multipart AEAD
@@ -3466,7 +3547,6 @@
     operation->key_set = 0;
     operation->nonce_set = 0;
     operation->lengths_set = 0;
-    operation->is_encrypt = 0;
     operation->ad_started = 0;
     operation->body_started = 0;