Apply review feedback

* Reworked the cipher context once again to be more robustly defined
* Removed redundant memset
* Unified behaviour on failure between driver and software in cipher_finish
* Cipher test driver setup function now also returns early when its status
  is overridden, like the other test driver functions
* Removed redundant test cases
* Added bad-order checking to verify the driver doesn't get called where
  the spec says it won't.

Signed-off-by: Steven Cooreman <steven.cooreman@silabs.com>
diff --git a/include/psa/crypto_struct.h b/include/psa/crypto_struct.h
index 861850c..a85a9bf 100644
--- a/include/psa/crypto_struct.h
+++ b/include/psa/crypto_struct.h
@@ -168,7 +168,7 @@
     unsigned int key_set : 1;
     unsigned int iv_required : 1;
     unsigned int iv_set : 1;
-    unsigned int driver_in_use : 1; /* Indicates a driver is used instead of software fallback. */
+    unsigned int mbedtls_in_use : 1; /* Indicates mbed TLS is handling the operation. */
     uint8_t iv_size;
     uint8_t block_size;
     union
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 6b25903..8383eae 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -4057,17 +4057,19 @@
 
     /* The requested algorithm must be one that can be processed by cipher. */
     if( ! PSA_ALG_IS_CIPHER( alg ) )
-    {
-        memset( operation, 0, sizeof( *operation ) );
         return( PSA_ERROR_INVALID_ARGUMENT );
-    }
 
-    /* Reset the operation members to their initial state, except for alg. The
-     * alg member is used as an indicator that psa_cipher_abort needs to free
-     * allocated resources, which doesn't happen until later. */
+    /* Fetch key material from key storage. */
+    status = psa_get_key_from_slot( handle, &slot, usage, alg );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    /* Initialize the operation struct members, except for alg. The alg member
+     * is used to indicate to psa_cipher_abort that there are resources to free,
+     * so we only set it after resources have been allocated/initialized. */
     operation->key_set = 0;
     operation->iv_set = 0;
-    operation->driver_in_use = 0;
+    operation->mbedtls_in_use = 0;
     operation->iv_size = 0;
     operation->block_size = 0;
     if( alg == PSA_ALG_ECB_NO_PADDING )
@@ -4075,11 +4077,6 @@
     else
         operation->iv_required = 1;
 
-    /* Fetch key material from key storage. */
-    status = psa_get_key_from_slot( handle, &slot, usage, alg );
-    if( status != PSA_SUCCESS )
-        goto exit;
-
     /* Try doing the operation through a driver before using software fallback. */
     if( cipher_operation == MBEDTLS_ENCRYPT )
         status = psa_driver_wrapper_cipher_encrypt_setup( &operation->ctx.driver,
@@ -4090,32 +4087,25 @@
                                                           slot,
                                                           alg );
 
+    if( status == PSA_SUCCESS )
+        /* Once the driver context is initialised, it needs to be freed using
+        * psa_cipher_abort. Indicate this through setting alg. */
+        operation->alg = alg;
+
     if( status != PSA_ERROR_NOT_SUPPORTED ||
         psa_key_lifetime_is_external( slot->attr.lifetime ) )
-    {
-        /* Indicate this operation is bound to a driver. When the driver setup
-         * succeeded, this indicates to the core to not call any mbedtls_
-         * functions for this operation (contexts are not interoperable).
-         * In case the drivers couldn't setup and there's no way to fallback,
-         * indicate to the core to not call mbedtls_cipher_free on an
-         * uninitialised mbed TLS cipher context. */
-        operation->driver_in_use = 1;
-
-        /* If the wrapper call succeeded, it allocated resources that need to be
-         * freed using psa_cipher_abort. Indicate this through setting alg. */
-        if( status == PSA_SUCCESS )
-            operation->alg = alg;
-
         goto exit;
-    }
 
     /* Proceed with initializing an mbed TLS cipher context if no driver is
      * available for the given algorithm & key. */
     mbedtls_cipher_init( &operation->ctx.cipher );
 
     /* Once the cipher context is initialised, it needs to be freed using
-     * psa_cipher_abort. Indicate this through setting alg. */
+     * psa_cipher_abort. Indicate there is something to be freed through setting
+     * alg, and indicate the operation is being done using mbedtls crypto through
+     * setting mbedtls_in_use. */
     operation->alg = alg;
+    operation->mbedtls_in_use = 1;
 
     key_bits = psa_get_key_slot_bits( slot );
     cipher_info = mbedtls_cipher_info_from_psa( alg, slot->attr.type, key_bits, NULL );
@@ -4224,7 +4214,7 @@
         return( PSA_ERROR_BAD_STATE );
     }
 
-    if( operation->driver_in_use == 1 )
+    if( operation->mbedtls_in_use == 0 )
     {
         status = psa_driver_wrapper_cipher_generate_iv( &operation->ctx.driver,
                                                         iv,
@@ -4268,7 +4258,7 @@
         return( PSA_ERROR_BAD_STATE );
     }
 
-    if( operation->driver_in_use == 1 )
+    if( operation->mbedtls_in_use == 0 )
     {
         status = psa_driver_wrapper_cipher_set_iv( &operation->ctx.driver,
                                                    iv,
@@ -4397,7 +4387,7 @@
         return( PSA_ERROR_BAD_STATE );
     }
 
-    if( operation->driver_in_use == 1 )
+    if( operation->mbedtls_in_use == 0 )
     {
         status = psa_driver_wrapper_cipher_update( &operation->ctx.driver,
                                                    input,
@@ -4459,7 +4449,6 @@
                                 size_t *output_length )
 {
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
-    int cipher_ret = MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE;
     uint8_t temp_output_buffer[MBEDTLS_MAX_BLOCK_LENGTH];
     if( operation->alg == 0 )
     {
@@ -4470,17 +4459,13 @@
         return( PSA_ERROR_BAD_STATE );
     }
 
-    if( operation->driver_in_use == 1 )
+    if( operation->mbedtls_in_use == 0 )
     {
         status = psa_driver_wrapper_cipher_finish( &operation->ctx.driver,
                                                    output,
                                                    output_size,
                                                    output_length );
-        if( status != PSA_SUCCESS )
-            goto error;
-
-        (void) psa_cipher_abort( operation );
-        return( status );
+        goto exit;
     }
 
     if( operation->ctx.cipher.unprocessed_len != 0 )
@@ -4490,18 +4475,16 @@
               operation->ctx.cipher.operation == MBEDTLS_ENCRYPT ) )
         {
             status = PSA_ERROR_INVALID_ARGUMENT;
-            goto error;
+            goto exit;
         }
     }
 
-    cipher_ret = mbedtls_cipher_finish( &operation->ctx.cipher,
-                                        temp_output_buffer,
-                                        output_length );
-    if( cipher_ret != 0 )
-    {
-        status = mbedtls_to_psa_error( cipher_ret );
-        goto error;
-    }
+    status = mbedtls_to_psa_error(
+        mbedtls_cipher_finish( &operation->ctx.cipher,
+                               temp_output_buffer,
+                               output_length ) );
+    if( status != PSA_SUCCESS )
+        goto exit;
 
     if( *output_length == 0 )
         ; /* Nothing to copy. Note that output may be NULL in this case. */
@@ -4510,22 +4493,24 @@
     else
     {
         status = PSA_ERROR_BUFFER_TOO_SMALL;
-        goto error;
+        goto exit;
     }
 
-    mbedtls_platform_zeroize( temp_output_buffer, sizeof( temp_output_buffer ) );
-    status = psa_cipher_abort( operation );
+exit:
+    if( operation->mbedtls_in_use == 1 )
+        mbedtls_platform_zeroize( temp_output_buffer, sizeof( temp_output_buffer ) );
 
-    return( status );
+    if( status == PSA_SUCCESS )
+        return( psa_cipher_abort( operation ) );
+    else
+    {
+        *output_length = 0;
 
-error:
+        mbedtls_platform_zeroize( temp_output_buffer, sizeof( temp_output_buffer ) );
+        (void) psa_cipher_abort( operation );
 
-    *output_length = 0;
-
-    mbedtls_platform_zeroize( temp_output_buffer, sizeof( temp_output_buffer ) );
-    (void) psa_cipher_abort( operation );
-
-    return( status );
+        return( status );
+    }
 }
 
 psa_status_t psa_cipher_abort( psa_cipher_operation_t *operation )
@@ -4543,7 +4528,7 @@
     if( ! PSA_ALG_IS_CIPHER( operation->alg ) )
         return( PSA_ERROR_BAD_STATE );
 
-    if( operation->driver_in_use == 1 )
+    if( operation->mbedtls_in_use == 0 )
         psa_driver_wrapper_cipher_abort( &operation->ctx.driver );
     else
         mbedtls_cipher_free( &operation->ctx.cipher );
@@ -4551,7 +4536,7 @@
     operation->alg = 0;
     operation->key_set = 0;
     operation->iv_set = 0;
-    operation->driver_in_use = 0;
+    operation->mbedtls_in_use = 0;
     operation->iv_size = 0;
     operation->block_size = 0;
     operation->iv_required = 0;
diff --git a/tests/src/drivers/cipher.c b/tests/src/drivers/cipher.c
index f9106d1..fa7c6a9 100644
--- a/tests/src/drivers/cipher.c
+++ b/tests/src/drivers/cipher.c
@@ -225,6 +225,10 @@
      * struct. */
     memset( operation, 0, sizeof( *operation ) );
 
+    /* Allow overriding return value for testing purposes */
+    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
+        return( test_driver_cipher_hooks.forced_status );
+
     /* Test driver supports AES-CTR only, to verify operation calls. */
     if( alg != PSA_ALG_CTR ||
         psa_get_key_type( attributes ) != PSA_KEY_TYPE_AES )
@@ -258,10 +262,6 @@
     operation->iv_required = 1;
     operation->key_set = 1;
 
-    /* Allow overriding return value for testing purposes */
-    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
-        mbedtls_cipher_free( &operation->cipher );
-
     return( test_driver_cipher_hooks.forced_status );
 }
 
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.data b/tests/suites/test_suite_psa_crypto_driver_wrappers.data
index 7b5d6bd..7abc256 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.data
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.data
@@ -56,14 +56,6 @@
 depends_on:MBEDTLS_AES_C:MBEDTLS_CIPHER_MODE_CTR
 cipher_encrypt:PSA_ALG_CTR:PSA_KEY_TYPE_AES:"2b7e151628aed2a6abf7158809cf4f3c":"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a":"6bc1bee22e409f96e93d7e11739317":"8f9408fe80a81d3e813da3c7b0b2bd":0:PSA_ERROR_NOT_SUPPORTED:PSA_SUCCESS
 
-PSA symmetric encrypt: AES-CTR, 16 bytes, fallback w/ fake
-depends_on:MBEDTLS_AES_C:MBEDTLS_CIPHER_MODE_CTR
-cipher_encrypt:PSA_ALG_CTR:PSA_KEY_TYPE_AES:"2b7e151628aed2a6abf7158809cf4f3c":"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a":"6bc1bee22e409f96e93d7e117393172a":"8f9408fe80a81d3e813da3c7b0b2bd32":1:PSA_ERROR_NOT_SUPPORTED:PSA_SUCCESS
-
-PSA symmetric encrypt: AES-CTR, 15 bytes, fallback w/ fake
-depends_on:MBEDTLS_AES_C:MBEDTLS_CIPHER_MODE_CTR
-cipher_encrypt:PSA_ALG_CTR:PSA_KEY_TYPE_AES:"2b7e151628aed2a6abf7158809cf4f3c":"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a":"6bc1bee22e409f96e93d7e11739317":"8f9408fe80a81d3e813da3c7b0b2bd":1:PSA_ERROR_NOT_SUPPORTED:PSA_SUCCESS
-
 PSA symmetric encrypt: AES-CTR, 16 bytes, fake
 depends_on:MBEDTLS_AES_C:MBEDTLS_CIPHER_MODE_CTR
 cipher_encrypt:PSA_ALG_CTR:PSA_KEY_TYPE_AES:"2b7e151628aed2a6abf7158809cf4f3c":"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a":"6bc1bee22e409f96e93d7e117393172a":"d07a6a6e2687feb2":1:PSA_SUCCESS:PSA_SUCCESS
@@ -80,10 +72,6 @@
 depends_on:MBEDTLS_AES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_MODE_CTR
 cipher_decrypt:PSA_ALG_CTR:PSA_KEY_TYPE_AES:"2b7e151628aed2a6abf7158809cf4f3c":"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a":"396ee84fb75fdbb5c2b13c7fe5a654aa":"dd3b5e5319b7591daab1e1a92687feb2":0:PSA_ERROR_NOT_SUPPORTED:PSA_SUCCESS
 
-PSA symmetric decrypt: AES-CTR, 16 bytes, fallback w/ fake
-depends_on:MBEDTLS_AES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_MODE_CTR
-cipher_decrypt:PSA_ALG_CTR:PSA_KEY_TYPE_AES:"2b7e151628aed2a6abf7158809cf4f3c":"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a":"396ee84fb75fdbb5c2b13c7fe5a654aa":"dd3b5e5319b7591daab1e1a92687feb2":1:PSA_ERROR_NOT_SUPPORTED:PSA_SUCCESS
-
 PSA symmetric decrypt: AES-CTR, 16 bytes, fake
 depends_on:MBEDTLS_AES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_MODE_CTR
 cipher_decrypt:PSA_ALG_CTR:PSA_KEY_TYPE_AES:"2b7e151628aed2a6abf7158809cf4f3c":"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a":"396ee84fb75fdbb5c2b13c7fe5a654aa":"d07a6a6e2687feb2":1:PSA_SUCCESS:PSA_SUCCESS
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index af0c7ee..951670d 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -558,6 +558,9 @@
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     test_driver_cipher_hooks = test_driver_cipher_hooks_init();
 
+    ASSERT_ALLOC( output, input->len + 16 );
+    output_buffer_size = input->len + 16;
+
     PSA_ASSERT( psa_crypto_init( ) );
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT );
@@ -574,6 +577,9 @@
     TEST_EQUAL( test_driver_cipher_hooks.hits, 1 );
     TEST_EQUAL( status, test_driver_cipher_hooks.forced_status );
     test_driver_cipher_hooks.hits = 0;
+    status = psa_cipher_set_iv( &operation, iv->x, iv->len );
+    TEST_EQUAL( status, PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( test_driver_cipher_hooks.hits, 0 );
 
     /* Test setup call failure, decrypt */
     status = psa_cipher_decrypt_setup( &operation,
@@ -582,6 +588,9 @@
     TEST_EQUAL( test_driver_cipher_hooks.hits, 1 );
     TEST_EQUAL( status, test_driver_cipher_hooks.forced_status );
     test_driver_cipher_hooks.hits = 0;
+    status = psa_cipher_set_iv( &operation, iv->x, iv->len );
+    TEST_EQUAL( status, PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( test_driver_cipher_hooks.hits, 0 );
 
     /* Test IV setting failure */
     test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
@@ -596,8 +605,15 @@
     /* When setting the IV fails, it should call abort too */
     TEST_EQUAL( test_driver_cipher_hooks.hits, 2 );
     TEST_EQUAL( status, test_driver_cipher_hooks.forced_status );
-    psa_cipher_abort( &operation );
+    /* Failure should prevent further operations from executing on the driver */
     test_driver_cipher_hooks.hits = 0;
+    status = psa_cipher_update( &operation,
+                                input->x, input->len,
+                                output, output_buffer_size,
+                                &function_output_length );
+    TEST_EQUAL( status, PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( test_driver_cipher_hooks.hits, 0 );
+    psa_cipher_abort( &operation );
 
     /* Test IV generation failure */
     test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
@@ -608,15 +624,19 @@
     test_driver_cipher_hooks.hits = 0;
 
     test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
-    ASSERT_ALLOC( output, 16 );
-    status = psa_cipher_generate_iv( &operation, output, 16, &output_buffer_size );
-    /* When setting the IV fails, it should call abort too */
+    status = psa_cipher_generate_iv( &operation, output, 16, &function_output_length );
+    /* When generating the IV fails, it should call abort too */
     TEST_EQUAL( test_driver_cipher_hooks.hits, 2 );
     TEST_EQUAL( status, test_driver_cipher_hooks.forced_status );
-    mbedtls_free( output );
-    output = NULL;
-    psa_cipher_abort( &operation );
+    /* Failure should prevent further operations from executing on the driver */
     test_driver_cipher_hooks.hits = 0;
+    status = psa_cipher_update( &operation,
+                                input->x, input->len,
+                                output, output_buffer_size,
+                                &function_output_length );
+    TEST_EQUAL( status, PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( test_driver_cipher_hooks.hits, 0 );
+    psa_cipher_abort( &operation );
 
     /* Test update failure */
     test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
@@ -632,8 +652,6 @@
     test_driver_cipher_hooks.hits = 0;
 
     test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
-    ASSERT_ALLOC( output, input->len + 16 );
-    output_buffer_size = input->len + 16;
     status = psa_cipher_update( &operation,
                                 input->x, input->len,
                                 output, output_buffer_size,
@@ -641,10 +659,15 @@
     /* When the update call fails, it should call abort too */
     TEST_EQUAL( test_driver_cipher_hooks.hits, 2 );
     TEST_EQUAL( status, test_driver_cipher_hooks.forced_status );
-    mbedtls_free( output );
-    output = NULL;
-    psa_cipher_abort( &operation );
+    /* Failure should prevent further operations from executing on the driver */
     test_driver_cipher_hooks.hits = 0;
+    status = psa_cipher_update( &operation,
+                                input->x, input->len,
+                                output, output_buffer_size,
+                                &function_output_length );
+    TEST_EQUAL( status, PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( test_driver_cipher_hooks.hits, 0 );
+    psa_cipher_abort( &operation );
 
     /* Test finish failure */
     test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
@@ -659,8 +682,6 @@
     TEST_EQUAL( status, test_driver_cipher_hooks.forced_status );
     test_driver_cipher_hooks.hits = 0;
 
-    ASSERT_ALLOC( output, input->len + 16 );
-    output_buffer_size = input->len + 16;
     status = psa_cipher_update( &operation,
                                 input->x, input->len,
                                 output, output_buffer_size,
@@ -677,10 +698,15 @@
     /* When the finish call fails, it should call abort too */
     TEST_EQUAL( test_driver_cipher_hooks.hits, 2 );
     TEST_EQUAL( status, test_driver_cipher_hooks.forced_status );
-    mbedtls_free( output );
-    output = NULL;
-    psa_cipher_abort( &operation );
+    /* Failure should prevent further operations from executing on the driver */
     test_driver_cipher_hooks.hits = 0;
+    status = psa_cipher_update( &operation,
+                                input->x, input->len,
+                                output, output_buffer_size,
+                                &function_output_length );
+    TEST_EQUAL( status, PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( test_driver_cipher_hooks.hits, 0 );
+    psa_cipher_abort( &operation );
 
 exit:
     psa_cipher_abort( &operation );