psa: cipher: Pass Mbed TLS implementation its operation ctx

As per drivers, pass to the Mbed TLS implementation of
the cipher multi-part operation its operation context
and not the PSA operation context.

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/psa_crypto_cipher.c b/library/psa_crypto_cipher.c
index a2b2942..e86aa95 100644
--- a/library/psa_crypto_cipher.c
+++ b/library/psa_crypto_cipher.c
@@ -32,7 +32,7 @@
 #include <string.h>
 
 static psa_status_t cipher_setup(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg,
@@ -42,22 +42,21 @@
     size_t key_bits;
     const mbedtls_cipher_info_t *cipher_info = NULL;
     psa_key_type_t key_type = attributes->core.type;
-    mbedtls_psa_cipher_operation_t *mbedtls_ctx = &operation->ctx.mbedtls_ctx;
 
     (void)key_buffer_size;
 
     /* Proceed with initializing an mbed TLS cipher context if no driver is
      * available for the given algorithm & key. */
-    mbedtls_cipher_init( &mbedtls_ctx->cipher );
+    mbedtls_cipher_init( &operation->cipher );
 
-    mbedtls_ctx->alg = alg;
+    operation->alg = alg;
     key_bits = attributes->core.bits;
     cipher_info = mbedtls_cipher_info_from_psa( alg, key_type,
                                                 key_bits, NULL );
     if( cipher_info == NULL )
         return( PSA_ERROR_NOT_SUPPORTED );
 
-    ret = mbedtls_cipher_setup( &mbedtls_ctx->cipher, cipher_info );
+    ret = mbedtls_cipher_setup( &operation->cipher, cipher_info );
     if( ret != 0 )
         goto exit;
 
@@ -68,14 +67,14 @@
         uint8_t keys[24];
         memcpy( keys, key_buffer, 16 );
         memcpy( keys + 16, key_buffer, 8 );
-        ret = mbedtls_cipher_setkey( &mbedtls_ctx->cipher,
+        ret = mbedtls_cipher_setkey( &operation->cipher,
                                      keys,
                                      192, cipher_operation );
     }
     else
 #endif
     {
-        ret = mbedtls_cipher_setkey( &mbedtls_ctx->cipher, key_buffer,
+        ret = mbedtls_cipher_setkey( &operation->cipher, key_buffer,
                                      (int) key_bits, cipher_operation );
     }
     if( ret != 0 )
@@ -86,11 +85,11 @@
     switch( alg )
     {
         case PSA_ALG_CBC_NO_PADDING:
-            ret = mbedtls_cipher_set_padding_mode( &mbedtls_ctx->cipher,
+            ret = mbedtls_cipher_set_padding_mode( &operation->cipher,
                                                    MBEDTLS_PADDING_NONE );
             break;
         case PSA_ALG_CBC_PKCS7:
-            ret = mbedtls_cipher_set_padding_mode( &mbedtls_ctx->cipher,
+            ret = mbedtls_cipher_set_padding_mode( &operation->cipher,
                                                    MBEDTLS_PADDING_PKCS7 );
             break;
         default:
@@ -102,18 +101,18 @@
         goto exit;
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CBC_NO_PADDING || MBEDTLS_PSA_BUILTIN_ALG_CBC_PKCS7 */
 
-    mbedtls_ctx->block_size = ( PSA_ALG_IS_STREAM_CIPHER( alg ) ? 1 :
+    operation->block_size = ( PSA_ALG_IS_STREAM_CIPHER( alg ) ? 1 :
                               PSA_BLOCK_CIPHER_BLOCK_LENGTH( key_type ) );
     if( ( alg & PSA_ALG_CIPHER_FROM_BLOCK_FLAG ) != 0 &&
         alg != PSA_ALG_ECB_NO_PADDING )
     {
-        mbedtls_ctx->iv_size = PSA_BLOCK_CIPHER_BLOCK_LENGTH( key_type );
+        operation->iv_size = PSA_BLOCK_CIPHER_BLOCK_LENGTH( key_type );
     }
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_CHACHA20)
     else
     if( ( alg == PSA_ALG_STREAM_CIPHER ) &&
         ( key_type == PSA_KEY_TYPE_CHACHA20 ) )
-        mbedtls_ctx->iv_size = 12;
+        operation->iv_size = 12;
 #endif
 
 exit:
@@ -121,7 +120,7 @@
 }
 
 psa_status_t mbedtls_psa_cipher_encrypt_setup(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg )
@@ -132,7 +131,7 @@
 }
 
 psa_status_t mbedtls_psa_cipher_decrypt_setup(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg )
@@ -143,36 +142,33 @@
 }
 
 psa_status_t mbedtls_psa_cipher_generate_iv(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     uint8_t *iv, size_t iv_size, size_t *iv_length )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    mbedtls_psa_cipher_operation_t *mbedtls_ctx = &operation->ctx.mbedtls_ctx;
 
-    if( iv_size < mbedtls_ctx->iv_size )
+    if( iv_size < operation->iv_size )
         return( PSA_ERROR_BUFFER_TOO_SMALL );
 
     ret = mbedtls_psa_get_random( MBEDTLS_PSA_RANDOM_STATE,
-                                  iv, mbedtls_ctx->iv_size );
+                                  iv, operation->iv_size );
     if( ret != 0 )
         return( mbedtls_to_psa_error( ret ) );
 
-    *iv_length = mbedtls_ctx->iv_size;
+    *iv_length = operation->iv_size;
 
     return( mbedtls_psa_cipher_set_iv( operation, iv, *iv_length ) );
 }
 
-psa_status_t mbedtls_psa_cipher_set_iv( psa_cipher_operation_t *operation,
+psa_status_t mbedtls_psa_cipher_set_iv( mbedtls_psa_cipher_operation_t *operation,
                                         const uint8_t *iv,
                                         size_t iv_length )
 {
-    mbedtls_psa_cipher_operation_t *mbedtls_ctx = &operation->ctx.mbedtls_ctx;
-
-    if( iv_length != mbedtls_ctx->iv_size )
+    if( iv_length != operation->iv_size )
         return( PSA_ERROR_INVALID_ARGUMENT );
 
     return( mbedtls_to_psa_error(
-                mbedtls_cipher_set_iv( &mbedtls_ctx->cipher,
+                mbedtls_cipher_set_iv( &operation->cipher,
                                        iv, iv_length ) ) );
 }
 
@@ -264,7 +260,7 @@
     return( status );
 }
 
-psa_status_t mbedtls_psa_cipher_update( psa_cipher_operation_t *operation,
+psa_status_t mbedtls_psa_cipher_update( mbedtls_psa_cipher_operation_t *operation,
                                         const uint8_t *input,
                                         size_t input_length,
                                         uint8_t *output,
@@ -272,18 +268,17 @@
                                         size_t *output_length )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    mbedtls_psa_cipher_operation_t *mbedtls_ctx = &operation->ctx.mbedtls_ctx;
     size_t expected_output_size;
 
-    if( ! PSA_ALG_IS_STREAM_CIPHER( mbedtls_ctx->alg ) )
+    if( ! PSA_ALG_IS_STREAM_CIPHER( operation->alg ) )
     {
         /* Take the unprocessed partial block left over from previous
          * update calls, if any, plus the input to this call. Remove
          * the last partial block, if any. You get the data that will be
          * output in this call. */
         expected_output_size =
-            ( mbedtls_ctx->cipher.unprocessed_len + input_length )
-            / mbedtls_ctx->block_size * mbedtls_ctx->block_size;
+            ( operation->cipher.unprocessed_len + input_length )
+            / operation->block_size * operation->block_size;
     }
     else
     {
@@ -293,12 +288,12 @@
     if( output_size < expected_output_size )
         return( PSA_ERROR_BUFFER_TOO_SMALL );
 
-    if( mbedtls_ctx->alg == PSA_ALG_ECB_NO_PADDING )
+    if( operation->alg == PSA_ALG_ECB_NO_PADDING )
     {
         /* mbedtls_cipher_update has an API inconsistency: it will only
         * process a single block at a time in ECB mode. Abstract away that
         * inconsistency here to match the PSA API behaviour. */
-        status = psa_cipher_update_ecb( &mbedtls_ctx->cipher,
+        status = psa_cipher_update_ecb( &operation->cipher,
                                         input,
                                         input_length,
                                         output,
@@ -308,26 +303,25 @@
     else
     {
         status = mbedtls_to_psa_error(
-            mbedtls_cipher_update( &mbedtls_ctx->cipher, input,
+            mbedtls_cipher_update( &operation->cipher, input,
                                    input_length, output, output_length ) );
     }
 
     return( status );
 }
 
-psa_status_t mbedtls_psa_cipher_finish( psa_cipher_operation_t *operation,
+psa_status_t mbedtls_psa_cipher_finish( mbedtls_psa_cipher_operation_t *operation,
                                         uint8_t *output,
                                         size_t output_size,
                                         size_t *output_length )
 {
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
-    mbedtls_psa_cipher_operation_t *mbedtls_ctx = &operation->ctx.mbedtls_ctx;
     uint8_t temp_output_buffer[MBEDTLS_MAX_BLOCK_LENGTH];
 
-    if( mbedtls_ctx->cipher.unprocessed_len != 0 )
+    if( operation->cipher.unprocessed_len != 0 )
     {
-        if( mbedtls_ctx->alg == PSA_ALG_ECB_NO_PADDING ||
-            mbedtls_ctx->alg == PSA_ALG_CBC_NO_PADDING )
+        if( operation->alg == PSA_ALG_ECB_NO_PADDING ||
+            operation->alg == PSA_ALG_CBC_NO_PADDING )
         {
             status = PSA_ERROR_INVALID_ARGUMENT;
             goto exit;
@@ -335,7 +329,7 @@
     }
 
     status = mbedtls_to_psa_error(
-        mbedtls_cipher_finish( &mbedtls_ctx->cipher,
+        mbedtls_cipher_finish( &operation->cipher,
                                temp_output_buffer,
                                output_length ) );
     if( status != PSA_SUCCESS )
@@ -355,16 +349,14 @@
     return( status );
 }
 
-psa_status_t mbedtls_psa_cipher_abort( psa_cipher_operation_t *operation )
+psa_status_t mbedtls_psa_cipher_abort( mbedtls_psa_cipher_operation_t *operation )
 {
-    mbedtls_psa_cipher_operation_t *mbedtls_ctx = &operation->ctx.mbedtls_ctx;
-
     /* Sanity check (shouldn't happen: operation->alg should
      * always have been initialized to a valid value). */
-    if( ! PSA_ALG_IS_CIPHER( mbedtls_ctx->alg ) )
+    if( ! PSA_ALG_IS_CIPHER( operation->alg ) )
         return( PSA_ERROR_BAD_STATE );
 
-    mbedtls_cipher_free( &mbedtls_ctx->cipher );
+    mbedtls_cipher_free( &operation->cipher );
 
     return( PSA_SUCCESS );
 }
diff --git a/library/psa_crypto_cipher.h b/library/psa_crypto_cipher.h
index 3a58a81..127f18c 100644
--- a/library/psa_crypto_cipher.h
+++ b/library/psa_crypto_cipher.h
@@ -48,7 +48,7 @@
  * \retval #PSA_ERROR_CORRUPTION_DETECTED
  */
 psa_status_t mbedtls_psa_cipher_encrypt_setup(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg );
@@ -78,7 +78,7 @@
  * \retval #PSA_ERROR_CORRUPTION_DETECTED
  */
 psa_status_t mbedtls_psa_cipher_decrypt_setup(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg );
@@ -106,7 +106,7 @@
  * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
  */
 psa_status_t mbedtls_psa_cipher_generate_iv(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     uint8_t *iv, size_t iv_size, size_t *iv_length );
 
 /** Set the IV for a symmetric encryption or decryption operation.
@@ -130,7 +130,7 @@
  * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
  */
 psa_status_t mbedtls_psa_cipher_set_iv(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     const uint8_t *iv, size_t iv_length );
 
 /** Encrypt or decrypt a message fragment in an active cipher operation.
@@ -155,7 +155,7 @@
  * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
  */
 psa_status_t mbedtls_psa_cipher_update(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     const uint8_t *input, size_t input_length,
     uint8_t *output, size_t output_size, size_t *output_length );
 
@@ -186,7 +186,7 @@
  * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
  */
 psa_status_t mbedtls_psa_cipher_finish(
-    psa_cipher_operation_t *operation,
+    mbedtls_psa_cipher_operation_t *operation,
     uint8_t *output, size_t output_size, size_t *output_length );
 
 /** Abort a cipher operation.
@@ -204,6 +204,6 @@
  *
  * \retval #PSA_SUCCESS
  */
-psa_status_t mbedtls_psa_cipher_abort( psa_cipher_operation_t *operation );
+psa_status_t mbedtls_psa_cipher_abort( mbedtls_psa_cipher_operation_t *operation );
 
 #endif /* PSA_CRYPTO_CIPHER_H */
diff --git a/library/psa_crypto_driver_wrappers.c b/library/psa_crypto_driver_wrappers.c
index 7a9bc7e..af63fbf 100644
--- a/library/psa_crypto_driver_wrappers.c
+++ b/library/psa_crypto_driver_wrappers.c
@@ -756,13 +756,13 @@
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
             /* Fell through, meaning no accelerator supports this operation */
-            status = mbedtls_psa_cipher_encrypt_setup( operation,
+            status = mbedtls_psa_cipher_encrypt_setup( &operation->ctx.mbedtls_ctx,
                                                        attributes,
                                                        key_buffer,
                                                        key_buffer_size,
                                                        alg );
             if( status == PSA_SUCCESS )
-                 operation->id = PSA_CRYPTO_MBED_TLS_DRIVER_ID;
+                operation->id = PSA_CRYPTO_MBED_TLS_DRIVER_ID;
 
             return( status );
 
@@ -849,7 +849,7 @@
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
             /* Fell through, meaning no accelerator supports this operation */
-            status = mbedtls_psa_cipher_decrypt_setup( operation,
+            status = mbedtls_psa_cipher_decrypt_setup( &operation->ctx.mbedtls_ctx,
                                                        attributes,
                                                        key_buffer,
                                                        key_buffer_size,
@@ -905,7 +905,7 @@
     switch( operation->id )
     {
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
-            return( mbedtls_psa_cipher_generate_iv( operation,
+            return( mbedtls_psa_cipher_generate_iv( &operation->ctx.mbedtls_ctx,
                                                     iv,
                                                     iv_size,
                                                     iv_length ) );
@@ -939,7 +939,7 @@
     switch( operation->id )
     {
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
-            return( mbedtls_psa_cipher_set_iv( operation,
+            return( mbedtls_psa_cipher_set_iv( &operation->ctx.mbedtls_ctx,
                                                iv,
                                                iv_length ) );
 
@@ -972,7 +972,7 @@
     switch( operation->id )
     {
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
-            return( mbedtls_psa_cipher_update( operation,
+            return( mbedtls_psa_cipher_update( &operation->ctx.mbedtls_ctx,
                                                input,
                                                input_length,
                                                output,
@@ -1010,7 +1010,7 @@
     switch( operation->id )
     {
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
-            return( mbedtls_psa_cipher_finish( operation,
+            return( mbedtls_psa_cipher_finish( &operation->ctx.mbedtls_ctx,
                                                output,
                                                output_size,
                                                output_length ) );
@@ -1051,7 +1051,7 @@
     switch( operation->id )
     {
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
-            return( mbedtls_psa_cipher_abort( operation ) );
+            return( mbedtls_psa_cipher_abort( &operation->ctx.mbedtls_ctx ) );
 
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
 #if defined(PSA_CRYPTO_DRIVER_TEST)
diff --git a/tests/include/test/drivers/cipher.h b/tests/include/test/drivers/cipher.h
index 06efa98..a1eb512 100644
--- a/tests/include/test/drivers/cipher.h
+++ b/tests/include/test/drivers/cipher.h
@@ -31,7 +31,7 @@
 #include <psa/crypto.h>
 
 #include "mbedtls/cipher.h"
-typedef psa_cipher_operation_t test_transparent_cipher_operation_t;
+typedef mbedtls_psa_cipher_operation_t test_transparent_cipher_operation_t;
 
 typedef struct{
     unsigned int initialised : 1;