Split multipart AEAD contexts into two parts
Split into the data required by the internal implementation and the data
required by the driver implementation, with the remaining data (such as the
operation state flags) left over for the PSA layer.
Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 4ab0c63..7190aa4 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3214,6 +3214,25 @@
return( status );
}
+/* Map an AEAD algorithm, or any shortened-tag variant of it, to its base
+ * algorithm. Unrecognised algorithms are returned unchanged. */
+static psa_algorithm_t psa_aead_get_base_algorithm( psa_algorithm_t alg )
+{
+    switch( PSA_ALG_AEAD_WITH_SHORTENED_TAG( alg, 0 ) )
+    {
+        case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 0 ):
+            return( PSA_ALG_CCM );
+        case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_GCM, 0 ):
+            return( PSA_ALG_GCM );
+        case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CHACHA20_POLY1305, 0 ):
+            return( PSA_ALG_CHACHA20_POLY1305 );
+        default:
+            /* Must not return a psa_status_t (e.g. PSA_ERROR_NOT_SUPPORTED)
+             * here: it would be stored in operation->alg as a bogus id. */
+            return( alg );
+    }
+}
+
/* Set the key for a multipart authenticated encryption operation. */
psa_status_t psa_aead_encrypt_setup( psa_aead_operation_t *operation,
mbedtls_svc_key_id_t key,
@@ -3226,6 +3245,12 @@
if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
return( PSA_ERROR_NOT_SUPPORTED );
+ if( operation->key_set || operation->nonce_set ||
+ operation->ad_started || operation->body_started )
+ {
+ return( PSA_ERROR_BAD_STATE );
+ }
+
status = psa_get_and_lock_key_slot_with_policy(
key, &slot, PSA_KEY_USAGE_ENCRYPT, alg );
@@ -3242,6 +3267,7 @@
&attributes, slot->key.data,
slot->key.bytes, alg );
+ operation->key_type = psa_get_key_type( &attributes );
unlock_status = psa_unlock_key_slot( slot );
@@ -3250,6 +3276,12 @@
return( unlock_status );
}
+ if( status == PSA_SUCCESS )
+ {
+ operation->alg = psa_aead_get_base_algorithm( alg );
+ operation->key_set = 1;
+ }
+
return( status );
}
@@ -3265,6 +3297,12 @@
if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
return( PSA_ERROR_NOT_SUPPORTED );
+ if( operation->key_set || operation->nonce_set ||
+ operation->ad_started || operation->body_started )
+ {
+ return( PSA_ERROR_BAD_STATE );
+ }
+
status = psa_get_and_lock_key_slot_with_policy(
key, &slot, PSA_KEY_USAGE_DECRYPT, alg );
@@ -3281,6 +3319,7 @@
&attributes, slot->key.data,
slot->key.bytes, alg );
+ operation->key_type = psa_get_key_type( &attributes );
unlock_status = psa_unlock_key_slot( slot );
@@ -3289,6 +3328,12 @@
return( unlock_status );
}
+ if( status == PSA_SUCCESS )
+ {
+ operation->alg = psa_aead_get_base_algorithm( alg );
+ operation->key_set = 1;
+ }
+
return( status );
}
@@ -3341,14 +3386,23 @@
const uint8_t *nonce,
size_t nonce_length )
{
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
if( !operation->key_set || operation->nonce_set ||
operation->ad_started || operation->body_started )
{
return( PSA_ERROR_BAD_STATE );
}
- return( psa_driver_wrapper_aead_set_nonce( operation, nonce,
- nonce_length ) );
+ status = psa_driver_wrapper_aead_set_nonce( operation, nonce,
+ nonce_length );
+
+ if( status == PSA_SUCCESS )
+ {
+ operation->nonce_set = 1;
+ }
+
+ return( status );
}
/* Declare the lengths of the message and additional data for multipart AEAD. */
@@ -3356,26 +3410,44 @@
size_t ad_length,
size_t plaintext_length )
{
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
if( !operation->key_set || operation->lengths_set )
{
return( PSA_ERROR_BAD_STATE );
}
- return( psa_driver_wrapper_aead_set_lengths( operation, ad_length,
- plaintext_length ) );
+ status = psa_driver_wrapper_aead_set_lengths( operation, ad_length,
+ plaintext_length );
+
+ if( status == PSA_SUCCESS )
+ {
+ operation->lengths_set = 1;
+ }
+
+ return status;
}
/* Pass additional data to an active multipart AEAD operation. */
psa_status_t psa_aead_update_ad( psa_aead_operation_t *operation,
const uint8_t *input,
size_t input_length )
{
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
if( !operation->nonce_set || !operation->key_set )
{
return( PSA_ERROR_BAD_STATE );
}
- return( psa_driver_wrapper_aead_update_ad( operation, input,
- input_length ) );
+ status = psa_driver_wrapper_aead_update_ad( operation, input,
+ input_length );
+
+ if( status == PSA_SUCCESS )
+ {
+ operation->ad_started = 1;
+ }
+
+ return status;
}
/* Encrypt or decrypt a message fragment in an active multipart AEAD
@@ -3387,6 +3459,7 @@
size_t output_size,
size_t *output_length )
{
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
*output_length = 0;
@@ -3395,9 +3468,16 @@
return( PSA_ERROR_BAD_STATE );
}
- return( psa_driver_wrapper_aead_update( operation, input, input_length,
- output, output_size,
- output_length ) );
+ status = psa_driver_wrapper_aead_update( operation, input, input_length,
+ output, output_size,
+ output_length );
+
+ if( status == PSA_SUCCESS )
+ {
+ operation->body_started = 1;
+ }
+
+ return status;
}
/* Finish encrypting a message in a multipart AEAD operation. */
@@ -3422,6 +3502,7 @@
ciphertext_size,
ciphertext_length,
tag, tag_size, tag_length ) );
+
}
/* Finish authenticating and decrypting a message in a multipart AEAD
@@ -3466,7 +3547,6 @@
operation->key_set = 0;
operation->nonce_set = 0;
operation->lengths_set = 0;
- operation->is_encrypt = 0;
operation->ad_started = 0;
operation->body_started = 0;