Move CCM ouput to update step.

Move CCM to update all data at update step, as final step can only
output at most a block length, so outputting all data at this step
significantly breaks the tests. Had to add unpleasant workaround for the
validate stage, but this is the only way I can do things without
breaking CCM Alt implementations.

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/include/psa/crypto_struct.h b/include/psa/crypto_struct.h
index 6f0fc01..90a0c20 100644
--- a/include/psa/crypto_struct.h
+++ b/include/psa/crypto_struct.h
@@ -179,11 +179,13 @@
 
     /* Buffers for AD/data - only required until CCM gets proper multipart
        support. */
-    uint8_t* ad_buffer;
+    uint8_t *ad_buffer;
     size_t ad_length;
 
-    uint8_t* data_buffer;
-    size_t data_length;
+    uint8_t *body_buffer;
+    uint8_t body_length;
+
+    uint8_t *tag_buffer;
 
     /* buffer to store Nonce - only required until CCM and GCM get proper
        multipart support. */
@@ -205,7 +207,7 @@
     } ctx;
 };
 
-#define PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {0}, {0}}
+#define PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {0}, {0}}
 static inline struct psa_aead_operation_s psa_aead_operation_init( void )
 {
     const struct psa_aead_operation_s v = PSA_AEAD_OPERATION_INIT;
diff --git a/library/psa_crypto_aead.c b/library/psa_crypto_aead.c
index b559f7a..bfa271b 100644
--- a/library/psa_crypto_aead.c
+++ b/library/psa_crypto_aead.c
@@ -613,18 +613,9 @@
 {
     size_t update_output_size;
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM)
-    if( operation->alg == PSA_ALG_CCM )
-    {
-        /* CCM will currently not output anything until finish. */
-        update_output_size = 0;
-    }
-    else
-#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_CCM) */
-    {
-        update_output_size = input_length;
-    }
+    update_output_size = input_length;
 
     if( PSA_AEAD_UPDATE_OUTPUT_SIZE( operation->key_type, operation->alg,
                                         input_length ) > output_size )
@@ -678,27 +669,78 @@
     if( operation->alg == PSA_ALG_CCM )
     {
         /* CCM dooes not support multipart yet, so all the input has to be
-           passed in in one go. Store the data for the final step.*/
+           passed in in one go. */
         if( operation->body_started )
         {
             return( PSA_ERROR_BAD_STATE );
         }
 
-        /* Save the additional data for later, this will be passed in
-           when we have the body. */
-        operation->data_buffer = ( uint8_t * ) mbedtls_calloc(1, input_length );
+        /* Need to store tag for Finish() / Verify() */
+        operation->tag_buffer = ( uint8_t * ) mbedtls_calloc(1, operation->tag_length );
 
-        if( operation->data_buffer )
+        if( operation->tag_buffer )
         {
-            memcpy( operation->data_buffer, input, input_length );
-            operation->data_length = input_length;
-            status = PSA_SUCCESS;
+
+            if( operation->is_encrypt )
+            {
+                /* Perform oneshot CCM encryption with additional data already
+                   stored, as CCM does not support multipart yet.*/
+                status = mbedtls_to_psa_error( mbedtls_ccm_encrypt_and_tag( &operation->ctx.ccm,
+                                                                            input_length,
+                                                                            operation->nonce,
+                                                                            operation->nonce_length,
+                                                                            operation->ad_buffer,
+                                                                            operation->ad_length,
+                                                                            input,
+                                                                            output,
+                                                                            operation->tag_buffer,
+                                                                            operation->tag_length ) );
+
+                /* Even if the above operation fails, we no longer need the
+                   additional data.*/
+                mbedtls_free(operation->ad_buffer);
+                operation->ad_buffer = NULL;
+                operation->ad_length = 0;
+            }
+            else
+            {
+                /* Need to back up the body data so we can do this again
+                   later.*/
+                operation->body_buffer = ( uint8_t * ) mbedtls_calloc(1, input_length );
+
+                if( operation->body_buffer )
+                {
+                    memcpy( operation->body_buffer, input, input_length );
+                    operation->body_length = input_length;
+
+                    /* this will fail, as the tag is clearly false, but will write the
+                       decrypted data to the output buffer. */
+                    ret = mbedtls_ccm_auth_decrypt( &operation->ctx.ccm, input_length,
+                                                    operation->nonce, operation->nonce_length,
+                                                    operation->ad_buffer, operation->ad_length,
+                                                    input, output,
+                                                    operation->tag_buffer,
+                                                    operation->tag_length );
+
+                    if( ret == MBEDTLS_ERR_CCM_AUTH_FAILED )
+                    {
+                        status = PSA_SUCCESS;
+                    }
+                    else
+                    {
+                        status = mbedtls_to_psa_error( ret );
+                    }
+                }
+                else
+                {
+                    status = PSA_ERROR_INSUFFICIENT_MEMORY;
+                }
+            }
         }
         else
         {
-            return ( PSA_ERROR_INSUFFICIENT_MEMORY );
+            status = PSA_ERROR_INSUFFICIENT_MEMORY;
         }
-
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */
@@ -732,10 +774,10 @@
    mbedtls_psa_aead_verify() */
 static psa_status_t mbedtls_psa_aead_finish_checks( psa_aead_operation_t *operation,
                                                     size_t output_size,
-                                                    size_t tag_size,
-                                                    size_t *finish_output_size,
-                                                    size_t *output_tag_length )
+                                                    size_t tag_size )
 {
+    size_t finish_output_size;
+
     if( operation->lengths_set )
     {
         if( operation->ad_remaining != 0 || operation->body_remaining != 0 )
@@ -744,41 +786,28 @@
         }
     }
 
-    *output_tag_length = operation->tag_length;
-
-    if( tag_size < *output_tag_length)
+    if( tag_size < operation->tag_length )
     {
         return ( PSA_ERROR_BUFFER_TOO_SMALL );
     }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM)
-    if( operation->alg == PSA_ALG_CCM )
+    if( operation->is_encrypt )
     {
-        /* CCM will output all data at this step. */
-        *finish_output_size = operation->data_length;
+            finish_output_size = PSA_AEAD_FINISH_OUTPUT_SIZE( operation->key_type,
+                                                              operation->alg );
     }
     else
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */
     {
-        if( operation->is_encrypt )
-        {
-           *finish_output_size = PSA_AEAD_FINISH_OUTPUT_SIZE( operation->key_type,
-                                                              operation->alg );
-        }
-        else
-        {
-            *finish_output_size = PSA_AEAD_VERIFY_OUTPUT_SIZE( operation->key_type,
+            finish_output_size = PSA_AEAD_VERIFY_OUTPUT_SIZE( operation->key_type,
                                                                operation->alg );
-        }
     }
 
-    if( output_size < *finish_output_size )
+    if( output_size < finish_output_size )
     {
         return ( PSA_ERROR_BUFFER_TOO_SMALL );
     }
 
     return ( PSA_SUCCESS );
-
 }
 
 /* Finish encrypting a message in a multipart AEAD operation. */
@@ -791,11 +820,9 @@
                                       size_t *tag_length )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    size_t output_tag_length;
-    size_t finish_output_size;
+    size_t finish_output_size = 0;
 
-    status = mbedtls_psa_aead_finish_checks( operation, ciphertext_size, tag_size, &finish_output_size,
-                                             &output_tag_length);
+    status = mbedtls_psa_aead_finish_checks( operation, ciphertext_size, tag_size );
 
     if( status != PSA_SUCCESS )
     {
@@ -815,31 +842,13 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM)
     if( operation->alg == PSA_ALG_CCM )
     {
-        if( !operation->ad_buffer || !operation->data_buffer )
-        {
-            return( PSA_ERROR_BAD_STATE );
-        }
+        /* Copy the previously generated tag into place */
+        memcpy( tag, operation->tag_buffer, operation->tag_length );
 
-        /* Perform oneshot CCM encryption with data already stored, as
-           CCM does not support multipart yet.*/
-        status = mbedtls_to_psa_error( mbedtls_ccm_encrypt_and_tag( &operation->ctx.ccm,
-                                                                    operation->data_length,
-                                                                    operation->nonce,
-                                                                    operation->nonce_length,
-                                                                    operation->ad_buffer,
-                                                                    operation->ad_length,
-                                                                    operation->data_buffer,
-                                                                    ciphertext,
-                                                                    tag, tag_size ) );
+        mbedtls_free(operation->tag_buffer);
+        operation->tag_buffer = NULL;
 
-        /* Even if the above operation fails, we no longer need the data */
-        mbedtls_free(operation->ad_buffer);
-        operation->ad_buffer = NULL;
-        operation->ad_length = 0;
-
-        mbedtls_free(operation->data_buffer);
-        operation->data_buffer = NULL;
-        operation->data_length = 0;
+        status = PSA_SUCCESS;
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */
@@ -865,7 +874,7 @@
     if( status == PSA_SUCCESS )
     {
         *ciphertext_length = finish_output_size;
-        *tag_length = output_tag_length;
+        *tag_length = operation->tag_length;
     }
 
     mbedtls_psa_aead_abort(operation);
@@ -885,14 +894,15 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    size_t finish_output_size;
-    size_t output_tag_length;
+    uint8_t * temp_buffer;
+    size_t temp_buffer_size;
+
+    size_t finish_output_size = 0;
 
     int do_tag_check = 1;
     uint8_t check_tag[16];
 
-    status = mbedtls_psa_aead_finish_checks( operation, plaintext_size, tag_length, &finish_output_size,
-                                             &output_tag_length);
+    status = mbedtls_psa_aead_finish_checks( operation, plaintext_size, tag_length );
 
     if( status != PSA_SUCCESS )
     {
@@ -905,45 +915,58 @@
         /* Call finish to get the tag for comparison */
         status =  mbedtls_to_psa_error( mbedtls_gcm_finish( &operation->ctx.gcm,
                                                             check_tag,
-                                                            16 ) );
+                                                            operation->tag_length ) );
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM)
     if( operation->alg == PSA_ALG_CCM )
     {
-        if( !operation->ad_buffer || !operation->data_buffer )
+        if( !operation->ad_buffer || !operation->body_buffer )
         {
             return( PSA_ERROR_BAD_STATE );
         }
 
-        /* Perform oneshot CCM decryption with data already stored, as
-           CCM does not support multipart yet.*/
+        /* Perform oneshot CCM decryption *again*, as its the
+         * only way to get the tag, but this time throw away the
+           results, as verify cannot write that much data. */
+        temp_buffer_size = PSA_AEAD_UPDATE_OUTPUT_SIZE( operation->key_type,
+                                                        operation->alg, operation->body_length );
 
-        ret = mbedtls_ccm_auth_decrypt( &operation->ctx.ccm, operation->data_length,
-                                       operation->nonce, operation->nonce_length,
-                                       operation->ad_buffer, operation->ad_length,
-                                       operation->data_buffer, plaintext,
-                                       tag, tag_length );
+        temp_buffer = ( uint8_t * ) mbedtls_calloc(1, temp_buffer_size );
 
-        if( ret == MBEDTLS_ERR_CCM_AUTH_FAILED )
+        if( temp_buffer )
         {
-            status = PSA_ERROR_INVALID_SIGNATURE;
+            ret = mbedtls_ccm_auth_decrypt( &operation->ctx.ccm, operation->body_length,
+                                           operation->nonce, operation->nonce_length,
+                                           operation->ad_buffer, operation->ad_length,
+                                           operation->body_buffer, temp_buffer,
+                                           tag, tag_length );
+
+            if( ret == MBEDTLS_ERR_CCM_AUTH_FAILED )
+            {
+                status = PSA_ERROR_INVALID_SIGNATURE;
+            }
+            else
+            {
+                status = mbedtls_to_psa_error( ret );
+                do_tag_check = 0;
+            }
         }
         else
         {
-            status = mbedtls_to_psa_error( ret );
-            do_tag_check = 0;
+            status = PSA_ERROR_INSUFFICIENT_MEMORY;
         }
 
         /* Even if the above operation fails, we no longer need the data */
-        mbedtls_free(operation->ad_buffer);
-        operation->ad_buffer = NULL;
-        operation->ad_length = 0;
+        mbedtls_free(temp_buffer);
 
-        mbedtls_free(operation->data_buffer);
-        operation->data_buffer = NULL;
-        operation->data_length = 0;
+        mbedtls_free(operation->body_buffer);
+        operation->body_buffer = NULL;
+        operation->body_length = 0;
+
+        mbedtls_free(operation->tag_buffer);
+        operation->tag_buffer = NULL;
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */
@@ -953,6 +976,7 @@
         // call finish to get the tag for comparison.
         status = mbedtls_to_psa_error( mbedtls_chachapoly_finish( &operation->ctx.chachapoly,
                                                                   check_tag ) );
+
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305 */
@@ -1003,6 +1027,17 @@
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305 */
     }
 
+    mbedtls_free(operation->ad_buffer);
+    operation->ad_buffer = NULL;
+    operation->ad_length = 0;
+
+    mbedtls_free(operation->body_buffer);
+    operation->body_buffer = NULL;
+    operation->body_length = 0;
+
+    mbedtls_free(operation->tag_buffer);
+    operation->tag_buffer = NULL;
+
     return( PSA_SUCCESS );
 }