psa: cipher: Dispatch based on driver identifier

For cipher multi-part operations, dispatch based on
the driver identifier even in the case of the
Mbed TLS software implementation (viewed as a driver).
Also use the driver identifier to check that an
cipher operation context is active or not.

This aligns the way hash and cipher multi-part
operations are dispatched.

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/include/psa/crypto_struct.h b/include/psa/crypto_struct.h
index 3ccad24..491d952 100644
--- a/include/psa/crypto_struct.h
+++ b/include/psa/crypto_struct.h
@@ -73,11 +73,6 @@
 #include "psa/crypto_driver_contexts.h"
 
 typedef struct {
-    /** Unique ID indicating which driver got assigned to do the
-     * operation. Since driver contexts are driver-specific, swapping
-     * drivers halfway through the operation is not supported.
-     * ID values are auto-generated in psa_driver_wrappers.h */
-    unsigned int id;
     /** Context structure for the assigned driver, when id is not zero. */
     void* ctx;
 } psa_operation_driver_context_t;
@@ -143,10 +138,17 @@
 
 struct psa_cipher_operation_s
 {
+    /** Unique ID indicating which driver got assigned to do the
+     * operation. Since driver contexts are driver-specific, swapping
+     * drivers halfway through the operation is not supported.
+     * ID values are auto-generated in psa_crypto_driver_wrappers.h
+     * ID value zero means the context is not valid or not assigned to
+     * any driver (i.e. none of the driver contexts are active). */
+    unsigned int id;
+
     psa_algorithm_t alg;
     unsigned int iv_required : 1;
     unsigned int iv_set : 1;
-    unsigned int mbedtls_in_use : 1; /* Indicates mbed TLS is handling the operation. */
     uint8_t iv_size;
     uint8_t block_size;
     union
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 3dfee3b..f4d8a3e 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3393,7 +3393,7 @@
                               PSA_KEY_USAGE_DECRYPT );
 
     /* A context must be freshly initialized before it can be set up. */
-    if( operation->alg != 0 )
+    if( operation->id != 0 )
         return( PSA_ERROR_BAD_STATE );
 
     /* The requested algorithm must be one that can be processed by cipher. */
@@ -3405,11 +3405,12 @@
     if( status != PSA_SUCCESS )
         goto exit;
 
-    /* Initialize the operation struct members, except for alg. The alg member
+    /* Initialize the operation struct members, except for id. The id member
      * is used to indicate to psa_cipher_abort that there are resources to free,
-     * so we only set it after resources have been allocated/initialized. */
+     * so we only set it (in the driver wrapper) after resources have been
+     * allocated/initialized. */
+    operation->alg = alg;
     operation->iv_set = 0;
-    operation->mbedtls_in_use = 0;
     operation->iv_size = 0;
     operation->block_size = 0;
     if( alg == PSA_ALG_ECB_NO_PADDING )
@@ -3435,13 +3436,6 @@
                                                           slot->key.bytes,
                                                           alg );
 
-    if( status == PSA_SUCCESS )
-    {
-       /* Once the driver context is initialized, it needs to be freed using
-        * psa_cipher_abort. Indicate this through setting alg. */
-        operation->alg = alg;
-    }
-
 exit:
     if( status != PSA_SUCCESS )
         psa_cipher_abort( operation );
@@ -3472,7 +3466,7 @@
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
-    if( operation->alg == 0 )
+    if( operation->id == 0 )
     {
         return( PSA_ERROR_BAD_STATE );
     }
@@ -3501,7 +3495,7 @@
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
-    if( operation->alg == 0 )
+    if( operation->id == 0 )
     {
         return( PSA_ERROR_BAD_STATE );
     }
@@ -3531,7 +3525,7 @@
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
-    if( operation->alg == 0 )
+    if( operation->id == 0 )
     {
         return( PSA_ERROR_BAD_STATE );
     }
@@ -3559,7 +3553,7 @@
 {
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
 
-    if( operation->alg == 0 )
+    if( operation->id == 0 )
     {
         return( PSA_ERROR_BAD_STATE );
     }
@@ -3585,7 +3579,7 @@
 
 psa_status_t psa_cipher_abort( psa_cipher_operation_t *operation )
 {
-    if( operation->alg == 0 )
+    if( operation->id == 0 )
     {
         /* The object has (apparently) been initialized but it is not (yet)
          * in use. It's ok to call abort on such an object, and there's
@@ -3600,9 +3594,9 @@
 
     psa_driver_wrapper_cipher_abort( operation );
 
+    operation->id = 0;
     operation->alg = 0;
     operation->iv_set = 0;
-    operation->mbedtls_in_use = 0;
     operation->iv_size = 0;
     operation->block_size = 0;
     operation->iv_required = 0;
diff --git a/library/psa_crypto_cipher.c b/library/psa_crypto_cipher.c
index 91d471b..340f674 100644
--- a/library/psa_crypto_cipher.c
+++ b/library/psa_crypto_cipher.c
@@ -49,13 +49,7 @@
      * available for the given algorithm & key. */
     mbedtls_cipher_init( &operation->ctx.cipher );
 
-    /* Once the cipher context is initialised, it needs to be freed using
-     * psa_cipher_abort. Indicate there is something to be freed through setting
-     * alg, and indicate the operation is being done using mbedtls crypto through
-     * setting mbedtls_in_use. */
     operation->alg = alg;
-    operation->mbedtls_in_use = 1;
-
     key_bits = attributes->core.bits;
     cipher_info = mbedtls_cipher_info_from_psa( alg, key_type,
                                                 key_bits, NULL );
diff --git a/library/psa_crypto_driver_wrappers.c b/library/psa_crypto_driver_wrappers.c
index 883b944..7a9bc7e 100644
--- a/library/psa_crypto_driver_wrappers.c
+++ b/library/psa_crypto_driver_wrappers.c
@@ -741,8 +741,7 @@
             /* Declared with fallback == true */
             if( status == PSA_SUCCESS )
             {
-                operation->ctx.driver.id =
-                    PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID;
+                operation->id = PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID;
                 operation->ctx.driver.ctx = driver_ctx;
             }
             else
@@ -757,11 +756,15 @@
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
             /* Fell through, meaning no accelerator supports this operation */
-            return( mbedtls_psa_cipher_encrypt_setup( operation,
-                                                      attributes,
-                                                      key_buffer,
-                                                      key_buffer_size,
-                                                      alg ) );
+            status = mbedtls_psa_cipher_encrypt_setup( operation,
+                                                       attributes,
+                                                       key_buffer,
+                                                       key_buffer_size,
+                                                       alg );
+            if( status == PSA_SUCCESS )
+                 operation->id = PSA_CRYPTO_MBED_TLS_DRIVER_ID;
+
+            return( status );
 
         /* Add cases for opaque driver here */
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
@@ -779,7 +782,7 @@
                                                        alg );
             if( status == PSA_SUCCESS )
             {
-                operation->ctx.driver.id = PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID;
+                operation->id = PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID;
                 operation->ctx.driver.ctx = driver_ctx;
             }
             else
@@ -831,8 +834,7 @@
             /* Declared with fallback == true */
             if( status == PSA_SUCCESS )
             {
-                operation->ctx.driver.id =
-                    PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID;
+                operation->id = PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID;
                 operation->ctx.driver.ctx = driver_ctx;
             }
             else
@@ -847,11 +849,16 @@
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
             /* Fell through, meaning no accelerator supports this operation */
-            return( mbedtls_psa_cipher_decrypt_setup( operation,
-                                                      attributes,
-                                                      key_buffer,
-                                                      key_buffer_size,
-                                                      alg ) );
+            status = mbedtls_psa_cipher_decrypt_setup( operation,
+                                                       attributes,
+                                                       key_buffer,
+                                                       key_buffer_size,
+                                                       alg );
+            if( status == PSA_SUCCESS )
+                operation->id = PSA_CRYPTO_MBED_TLS_DRIVER_ID;
+
+            return( status );
+
         /* Add cases for opaque driver here */
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
 #if defined(PSA_CRYPTO_DRIVER_TEST)
@@ -868,7 +875,7 @@
                                                        alg );
             if( status == PSA_SUCCESS )
             {
-                operation->ctx.driver.id = PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID;
+                operation->id = PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID;
                 operation->ctx.driver.ctx = driver_ctx;
             }
             else
@@ -895,15 +902,14 @@
     size_t iv_size,
     size_t *iv_length )
 {
-    if( operation->mbedtls_in_use )
-        return( mbedtls_psa_cipher_generate_iv( operation,
-                                                iv,
-                                                iv_size,
-                                                iv_length ) );
-
-#if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
-    switch( operation->ctx.driver.id )
+    switch( operation->id )
     {
+        case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
+            return( mbedtls_psa_cipher_generate_iv( operation,
+                                                    iv,
+                                                    iv_size,
+                                                    iv_length ) );
+#if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
 #if defined(PSA_CRYPTO_DRIVER_TEST)
         case PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID:
             return( test_transparent_cipher_generate_iv(
@@ -911,9 +917,7 @@
                         iv,
                         iv_size,
                         iv_length ) );
-#endif /* PSA_CRYPTO_DRIVER_TEST */
 
-#if defined(PSA_CRYPTO_DRIVER_TEST)
         case PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID:
             return( test_opaque_cipher_generate_iv(
                         operation->ctx.driver.ctx,
@@ -921,8 +925,8 @@
                         iv_size,
                         iv_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
-    }
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
+    }
 
     return( PSA_ERROR_INVALID_ARGUMENT );
 }
@@ -932,28 +936,27 @@
     const uint8_t *iv,
     size_t iv_length )
 {
-    if( operation->mbedtls_in_use )
-        return( mbedtls_psa_cipher_set_iv( operation,
-                                           iv,
-                                           iv_length ) );
+    switch( operation->id )
+    {
+        case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
+            return( mbedtls_psa_cipher_set_iv( operation,
+                                               iv,
+                                               iv_length ) );
 
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
-    switch( operation->ctx.driver.id )
-    {
 #if defined(PSA_CRYPTO_DRIVER_TEST)
         case PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID:
             return( test_transparent_cipher_set_iv( operation->ctx.driver.ctx,
                                                     iv,
                                                     iv_length ) );
-#endif /* PSA_CRYPTO_DRIVER_TEST */
-#if defined(PSA_CRYPTO_DRIVER_TEST)
+
         case PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID:
             return( test_opaque_cipher_set_iv( operation->ctx.driver.ctx,
                                                iv,
                                                iv_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
-    }
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
+    }
 
     return( PSA_ERROR_INVALID_ARGUMENT );
 }
@@ -966,17 +969,16 @@
     size_t output_size,
     size_t *output_length )
 {
-    if( operation->mbedtls_in_use )
-        return( mbedtls_psa_cipher_update( operation,
-                                           input,
-                                           input_length,
-                                           output,
-                                           output_size,
-                                           output_length ) );
-
-#if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
-    switch( operation->ctx.driver.id )
+    switch( operation->id )
     {
+        case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
+            return( mbedtls_psa_cipher_update( operation,
+                                               input,
+                                               input_length,
+                                               output,
+                                               output_size,
+                                               output_length ) );
+#if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
 #if defined(PSA_CRYPTO_DRIVER_TEST)
         case PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID:
             return( test_transparent_cipher_update( operation->ctx.driver.ctx,
@@ -985,8 +987,6 @@
                                                     output,
                                                     output_size,
                                                     output_length ) );
-#endif /* PSA_CRYPTO_DRIVER_TEST */
-#if defined(PSA_CRYPTO_DRIVER_TEST)
         case PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID:
             return( test_opaque_cipher_update( operation->ctx.driver.ctx,
                                                input,
@@ -995,8 +995,8 @@
                                                output_size,
                                                output_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
-    }
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
+    }
 
     return( PSA_ERROR_INVALID_ARGUMENT );
 }
@@ -1007,31 +1007,31 @@
     size_t output_size,
     size_t *output_length )
 {
-    if( operation->mbedtls_in_use )
-        return( mbedtls_psa_cipher_finish( operation,
-                                           output,
-                                           output_size,
-                                           output_length ) );
+    switch( operation->id )
+    {
+        case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
+            return( mbedtls_psa_cipher_finish( operation,
+                                               output,
+                                               output_size,
+                                               output_length ) );
+
 
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
-    switch( operation->ctx.driver.id )
-    {
 #if defined(PSA_CRYPTO_DRIVER_TEST)
         case PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID:
             return( test_transparent_cipher_finish( operation->ctx.driver.ctx,
                                                     output,
                                                     output_size,
                                                     output_length ) );
-#endif /* PSA_CRYPTO_DRIVER_TEST */
-#if defined(PSA_CRYPTO_DRIVER_TEST)
+
         case PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID:
             return( test_opaque_cipher_finish( operation->ctx.driver.ctx,
                                                output,
                                                output_size,
                                                output_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
-    }
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
+    }
 
     return( PSA_ERROR_INVALID_ARGUMENT );
 }
@@ -1039,20 +1039,21 @@
 psa_status_t psa_driver_wrapper_cipher_abort(
     psa_cipher_operation_t *operation )
 {
-    if( operation->mbedtls_in_use )
-        return( mbedtls_psa_cipher_abort( operation ) );
-
-#if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_operation_driver_context_t *driver_context = &operation->ctx.driver;
 
     /* The object has (apparently) been initialized but it is not in use. It's
      * ok to call abort on such an object, and there's nothing to do. */
-    if( driver_context->ctx == NULL && driver_context->id == 0 )
+    if( ( operation->id != PSA_CRYPTO_MBED_TLS_DRIVER_ID ) &&
+        ( driver_context->ctx == NULL ) )
         return( PSA_SUCCESS );
 
-    switch( driver_context->id )
+    switch( operation->id )
     {
+        case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
+            return( mbedtls_psa_cipher_abort( operation ) );
+
+#if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
 #if defined(PSA_CRYPTO_DRIVER_TEST)
         case PSA_CRYPTO_TRANSPARENT_TEST_DRIVER_ID:
             status = test_transparent_cipher_abort( driver_context->ctx );
@@ -1061,11 +1062,9 @@
                 sizeof( test_transparent_cipher_operation_t ) );
             mbedtls_free( driver_context->ctx );
             driver_context->ctx = NULL;
-            driver_context->id = 0;
 
             return( status );
-#endif /* PSA_CRYPTO_DRIVER_TEST */
-#if defined(PSA_CRYPTO_DRIVER_TEST)
+
         case PSA_CRYPTO_OPAQUE_TEST_DRIVER_ID:
             status = test_opaque_cipher_abort( driver_context->ctx );
             mbedtls_platform_zeroize(
@@ -1073,13 +1072,13 @@
                 sizeof( test_opaque_cipher_operation_t ) );
             mbedtls_free( driver_context->ctx );
             driver_context->ctx = NULL;
-            driver_context->id = 0;
 
             return( status );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
-    }
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
+    }
 
+    (void)status;
     return( PSA_ERROR_INVALID_ARGUMENT );
 }