Implement proper multipart AEAD GCM support

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/library/psa_crypto_aead.c b/library/psa_crypto_aead.c
index 0e7ca63..1491b35 100644
--- a/library/psa_crypto_aead.c
+++ b/library/psa_crypto_aead.c
@@ -400,16 +400,12 @@
     #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM)
     if( operation->alg == PSA_ALG_GCM )
     {
-        operation->nonce = mbedtls_calloc( 1, nonce_length );
-
-        if( operation->nonce == NULL )
-            return( PSA_ERROR_INSUFFICIENT_MEMORY );
-
-        /* GCM sets nonce once additional data has been supplied */
-        memcpy( operation->nonce, nonce, nonce_length );
-
-        operation->nonce_length = nonce_length;
-        status = PSA_SUCCESS;
+        status = mbedtls_to_psa_error(
+                 mbedtls_gcm_starts( &operation->ctx.gcm,
+                                     operation->is_encrypt ?
+                                     MBEDTLS_GCM_ENCRYPT : MBEDTLS_GCM_DECRYPT,
+                                     nonce,
+                                     nonce_length ) );
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */
@@ -498,22 +494,8 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM)
     if( operation->alg == PSA_ALG_GCM )
     {
-         /* GCM currently requires all the additional data to be passed in
-          * in one contiguous buffer, so until that is re-done, we have to
-          * enforce this, as we cannot allocate a buffer to collate multiple
-          * calls into. */
-        if( operation->ad_started )
-            return( PSA_ERROR_NOT_SUPPORTED );
-
         status = mbedtls_to_psa_error(
-           mbedtls_gcm_starts( &operation->ctx.gcm,
-                               operation->is_encrypt ?
-                               MBEDTLS_GCM_ENCRYPT : MBEDTLS_GCM_DECRYPT,
-                               operation->nonce,
-                               operation->nonce_length,
-                               input,
-                               input_length ) );
-
+            mbedtls_gcm_update_ad( &operation->ctx.gcm, input, input_length ) );
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */
@@ -534,9 +516,6 @@
         return ( PSA_ERROR_NOT_SUPPORTED );
     }
 
-    if( status == PSA_SUCCESS )
-        operation->ad_started = 1;
-
     return ( status );
 }
 
@@ -562,18 +541,11 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM)
     if( operation->alg == PSA_ALG_GCM )
     {
-        /* For the time being set the requirement that all of the body data
-         * must be passed in in one update, rather than deal with the complexity
-         * of non block size aligned updates. This will be fixed in 3.0 when
-           we can change the signature of the GCM multipart functions */
-        if( operation->body_started )
-            return( PSA_ERROR_NOT_SUPPORTED );
-
-
-        status =  mbedtls_to_psa_error( mbedtls_gcm_update( &operation->ctx.gcm,
-                                                        input_length,
-                                                        input,
-                                                        output ) );
+        status =  mbedtls_to_psa_error(
+            mbedtls_gcm_update( &operation->ctx.gcm,
+                                input, input_length,
+                                output, output_size,
+                                &update_output_length ) );
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */
@@ -596,10 +568,7 @@
     }
 
     if( status == PSA_SUCCESS )
-    {
         *output_length = update_output_length;
-        operation->body_started = 1;
-    }
 
     return( status );
 }
@@ -647,17 +616,17 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM)
     if( operation->alg == PSA_ALG_GCM )
-        /* We will need to do final GCM pass in here when multipart is done. */
-        status =  mbedtls_to_psa_error( mbedtls_gcm_finish( &operation->ctx.gcm,
-                                                            tag,
-                                                            tag_size ) );
+        status =  mbedtls_to_psa_error(
+            mbedtls_gcm_finish( &operation->ctx.gcm,
+                                ciphertext, ciphertext_size,
+                                tag, tag_size ) );
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305)
     if( operation->alg == PSA_ALG_CHACHA20_POLY1305 )
         status = mbedtls_to_psa_error(
-           mbedtls_chachapoly_finish( &operation->ctx.chachapoly,
-                                      tag ) );
+            mbedtls_chachapoly_finish( &operation->ctx.chachapoly,
+                                       tag ) );
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305 */
     {
@@ -706,8 +675,8 @@
         /* Call finish to get the tag for comparison */
         status =  mbedtls_to_psa_error(
            mbedtls_gcm_finish( &operation->ctx.gcm,
-                               check_tag,
-                               operation->tag_length ) );
+                               plaintext, plaintext_size,
+                               check_tag, operation->tag_length ) );
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305)
@@ -765,15 +734,6 @@
     }
 
     operation->is_encrypt = 0;
-    operation->ad_started = 0;
-    operation->body_started = 0;
-
-    mbedtls_free( operation->tag_buffer );
-    operation->tag_buffer = NULL;
-
-    mbedtls_free( operation->nonce );
-    operation->nonce = NULL;
-    operation->nonce_length = 0;
 
     return( PSA_SUCCESS );
 }