Update all uses of old AEAD output size macros

Signed-off-by: Bence Szépkúti <bence.szepkuti@arm.com>
diff --git a/library/psa_crypto_aead.c b/library/psa_crypto_aead.c
index 2632830..356679c 100644
--- a/library/psa_crypto_aead.c
+++ b/library/psa_crypto_aead.c
@@ -154,10 +154,14 @@
             return( PSA_ERROR_NOT_SUPPORTED );
     }
 
-    if( PSA_AEAD_TAG_LENGTH( alg ) > full_tag_length )
+    if( PSA_AEAD_TAG_LENGTH( attributes->core.type,
+                             key_bits, alg )
+        > full_tag_length )
         return( PSA_ERROR_INVALID_ARGUMENT );
 
-    operation->tag_length = PSA_AEAD_TAG_LENGTH( alg );
+    operation->tag_length = PSA_AEAD_TAG_LENGTH( attributes->core.type,
+                                                 key_bits,
+                                                 alg );
 
     return( PSA_SUCCESS );
 }
diff --git a/programs/psa/key_ladder_demo.c b/programs/psa/key_ladder_demo.c
index 47d5de6..5d64349 100644
--- a/programs/psa/key_ladder_demo.c
+++ b/programs/psa/key_ladder_demo.c
@@ -365,6 +365,8 @@
     psa_status_t status;
     FILE *input_file = NULL;
     FILE *output_file = NULL;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_type_t key_type;
     long input_position;
     size_t input_size;
     size_t buffer_size = 0;
@@ -385,7 +387,10 @@
     }
 #endif
     input_size = input_position;
-    buffer_size = PSA_AEAD_ENCRYPT_OUTPUT_SIZE( WRAPPING_ALG, input_size );
+    PSA_CHECK( psa_get_key_attributes( wrapping_key, &attributes ) );
+    key_type = psa_get_key_type( &attributes );
+    buffer_size =
+        PSA_AEAD_ENCRYPT_OUTPUT_SIZE( key_type, WRAPPING_ALG, input_size );
     /* Check for integer overflow. */
     if( buffer_size < input_size )
     {
@@ -442,6 +447,8 @@
     psa_status_t status;
     FILE *input_file = NULL;
     FILE *output_file = NULL;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_type_t key_type;
     unsigned char *buffer = NULL;
     size_t ciphertext_size = 0;
     size_t plaintext_size;
@@ -465,8 +472,10 @@
         status = DEMO_ERROR;
         goto exit;
     }
+    PSA_CHECK( psa_get_key_attributes( wrapping_key, &attributes) );
+    key_type = psa_get_key_type( &attributes);
     ciphertext_size =
-        PSA_AEAD_ENCRYPT_OUTPUT_SIZE( WRAPPING_ALG, header.payload_size );
+        PSA_AEAD_ENCRYPT_OUTPUT_SIZE( key_type, WRAPPING_ALG, header.payload_size );
     /* Check for integer overflow. */
     if( ciphertext_size < header.payload_size )
     {
diff --git a/programs/psa/psa_constant_names_generated.c b/programs/psa/psa_constant_names_generated.c
index 2175af9..dcbe87f 100644
--- a/programs/psa/psa_constant_names_generated.c
+++ b/programs/psa/psa_constant_names_generated.c
@@ -169,11 +169,11 @@
         } else if (alg & PSA_ALG_AEAD_AT_LEAST_THIS_LENGTH_FLAG) {
             append(&buffer, buffer_size, &required_size,
                    "PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(", 43);
-            length_modifier = PSA_AEAD_TAG_LENGTH(alg);
+            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
         } else if (core_alg != alg) {
             append(&buffer, buffer_size, &required_size,
                    "PSA_ALG_AEAD_WITH_SHORTENED_TAG(", 32);
-            length_modifier = PSA_AEAD_TAG_LENGTH(alg);
+            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
         }
     } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) &&
                !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py
index ff07ecd..71afd02 100755
--- a/scripts/generate_psa_constants.py
+++ b/scripts/generate_psa_constants.py
@@ -117,11 +117,11 @@
         } else if (alg & PSA_ALG_AEAD_AT_LEAST_THIS_LENGTH_FLAG) {
             append(&buffer, buffer_size, &required_size,
                    "PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(", 43);
-            length_modifier = PSA_AEAD_TAG_LENGTH(alg);
+            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
         } else if (core_alg != alg) {
             append(&buffer, buffer_size, &required_size,
                    "PSA_ALG_AEAD_WITH_SHORTENED_TAG(", 32);
-            length_modifier = PSA_AEAD_TAG_LENGTH(alg);
+            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
         }
     } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) &&
                !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index 4e56800..310b2a7 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -2985,24 +2985,16 @@
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_type_t key_type = key_type_arg;
     psa_algorithm_t alg = alg_arg;
+    size_t key_bits;
     unsigned char *output_data = NULL;
     size_t output_size = 0;
     size_t output_length = 0;
     unsigned char *output_data2 = NULL;
     size_t output_length2 = 0;
-    size_t tag_length = PSA_AEAD_TAG_LENGTH( alg );
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
     psa_status_t expected_result = expected_result_arg;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
 
-    output_size = input_data->len + tag_length;
-    /* For all currently defined algorithms, PSA_AEAD_ENCRYPT_OUTPUT_SIZE
-     * should be exact. */
-    if( expected_result != PSA_ERROR_INVALID_ARGUMENT )
-        TEST_EQUAL( output_size,
-                    PSA_AEAD_ENCRYPT_OUTPUT_SIZE( alg, input_data->len ) );
-    ASSERT_ALLOC( output_data, output_size );
-
     PSA_ASSERT( psa_crypto_init( ) );
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT );
@@ -3011,6 +3003,22 @@
 
     PSA_ASSERT( psa_import_key( &attributes, key_data->x, key_data->len,
                                 &key ) );
+    PSA_ASSERT( psa_get_key_attributes( key, &attributes ) );
+    key_bits = psa_get_key_bits( &attributes );
+
+    output_size = input_data->len + PSA_AEAD_TAG_LENGTH( key_type, key_bits,
+                                                         alg );
+    /* For all currently defined algorithms, PSA_AEAD_ENCRYPT_OUTPUT_SIZE
+     * should be exact. */
+    if( expected_result != PSA_ERROR_INVALID_ARGUMENT &&
+        expected_result != PSA_ERROR_NOT_SUPPORTED )
+    {
+        TEST_EQUAL( output_size,
+                    PSA_AEAD_ENCRYPT_OUTPUT_SIZE( key_type, alg, input_data->len ) );
+        TEST_ASSERT( output_size <=
+                     PSA_AEAD_ENCRYPT_OUTPUT_MAX_SIZE( input_data->len ) );
+    }
+    ASSERT_ALLOC( output_data, output_size );
 
     status = psa_aead_encrypt( key, alg,
                                nonce->x, nonce->len,
@@ -3038,7 +3046,7 @@
         /* For all currently defined algorithms, PSA_AEAD_DECRYPT_OUTPUT_SIZE
          * should be exact. */
         TEST_EQUAL( input_data->len,
-                    PSA_AEAD_DECRYPT_OUTPUT_SIZE( alg, output_length ) );
+                    PSA_AEAD_DECRYPT_OUTPUT_SIZE( key_type, alg, output_length ) );
 
         TEST_ASSERT( input_data->len <=
                      PSA_AEAD_DECRYPT_OUTPUT_MAX_SIZE( output_length ) );
@@ -3075,22 +3083,13 @@
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_type_t key_type = key_type_arg;
     psa_algorithm_t alg = alg_arg;
+    size_t key_bits;
     unsigned char *output_data = NULL;
     size_t output_size = 0;
     size_t output_length = 0;
-    size_t tag_length = PSA_AEAD_TAG_LENGTH( alg );
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
 
-    output_size = input_data->len + tag_length;
-    /* For all currently defined algorithms, PSA_AEAD_ENCRYPT_OUTPUT_SIZE
-     * should be exact. */
-    TEST_EQUAL( output_size,
-                PSA_AEAD_ENCRYPT_OUTPUT_SIZE( alg, input_data->len ) );
-    TEST_ASSERT( output_size <=
-                 PSA_AEAD_ENCRYPT_OUTPUT_MAX_SIZE( input_data->len ) );
-    ASSERT_ALLOC( output_data, output_size );
-
     PSA_ASSERT( psa_crypto_init( ) );
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT  );
@@ -3099,6 +3098,18 @@
 
     PSA_ASSERT( psa_import_key( &attributes, key_data->x, key_data->len,
                                 &key ) );
+    PSA_ASSERT( psa_get_key_attributes( key, &attributes ) );
+    key_bits = psa_get_key_bits( &attributes );
+
+    output_size = input_data->len + PSA_AEAD_TAG_LENGTH( key_type, key_bits,
+                                                         alg );
+    /* For all currently defined algorithms, PSA_AEAD_ENCRYPT_OUTPUT_SIZE
+     * should be exact. */
+    TEST_EQUAL( output_size,
+                PSA_AEAD_ENCRYPT_OUTPUT_SIZE( key_type, alg, input_data->len ) );
+    TEST_ASSERT( output_size <=
+                 PSA_AEAD_ENCRYPT_OUTPUT_MAX_SIZE( input_data->len ) );
+    ASSERT_ALLOC( output_data, output_size );
 
     status = psa_aead_encrypt( key, alg,
                                nonce->x, nonce->len,
@@ -3139,26 +3150,14 @@
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_type_t key_type = key_type_arg;
     psa_algorithm_t alg = alg_arg;
+    size_t key_bits;
     unsigned char *output_data = NULL;
     size_t output_size = 0;
     size_t output_length = 0;
-    size_t tag_length = PSA_AEAD_TAG_LENGTH( alg );
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_status_t expected_result = expected_result_arg;
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
 
-    output_size = input_data->len - tag_length;
-    if( expected_result != PSA_ERROR_INVALID_ARGUMENT )
-    {
-        /* For all currently defined algorithms, PSA_AEAD_DECRYPT_OUTPUT_SIZE
-         * should be exact. */
-        TEST_EQUAL( output_size,
-                    PSA_AEAD_DECRYPT_OUTPUT_SIZE( alg, input_data->len ) );
-        TEST_ASSERT( output_size <=
-                     PSA_AEAD_DECRYPT_OUTPUT_MAX_SIZE( input_data->len ) );
-    }
-    ASSERT_ALLOC( output_data, output_size );
-
     PSA_ASSERT( psa_crypto_init( ) );
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_DECRYPT  );
@@ -3167,6 +3166,22 @@
 
     PSA_ASSERT( psa_import_key( &attributes, key_data->x, key_data->len,
                                 &key ) );
+    PSA_ASSERT( psa_get_key_attributes( key, &attributes ) );
+    key_bits = psa_get_key_bits( &attributes );
+
+    output_size = input_data->len - PSA_AEAD_TAG_LENGTH( key_type, key_bits,
+                                                         alg );
+    if( expected_result != PSA_ERROR_INVALID_ARGUMENT &&
+        expected_result != PSA_ERROR_NOT_SUPPORTED )
+    {
+        /* For all currently defined algorithms, PSA_AEAD_DECRYPT_OUTPUT_SIZE
+         * should be exact. */
+        TEST_EQUAL( output_size,
+                    PSA_AEAD_DECRYPT_OUTPUT_SIZE( key_type, alg, input_data->len ) );
+        TEST_ASSERT( output_size <=
+                     PSA_AEAD_DECRYPT_OUTPUT_MAX_SIZE( input_data->len ) );
+    }
+    ASSERT_ALLOC( output_data, output_size );
 
     status = psa_aead_decrypt( key, alg,
                                nonce->x, nonce->len,
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index 20452b7..fc2a8e5 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -822,24 +822,15 @@
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_type_t key_type = key_type_arg;
     psa_algorithm_t alg = alg_arg;
+    size_t key_bits;
     psa_status_t forced_status = forced_status_arg;
     unsigned char *output_data = NULL;
     size_t output_size = 0;
     size_t output_length = 0;
-    size_t tag_length = PSA_AEAD_TAG_LENGTH( alg );
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
     test_driver_aead_hooks = test_driver_aead_hooks_init();
 
-    output_size = input_data->len + tag_length;
-    /* For all currently defined algorithms, PSA_AEAD_ENCRYPT_OUTPUT_SIZE
-     * should be exact. */
-    TEST_EQUAL( output_size,
-                PSA_AEAD_ENCRYPT_OUTPUT_SIZE( alg, input_data->len ) );
-    TEST_ASSERT( output_size <=
-                 PSA_AEAD_ENCRYPT_OUTPUT_MAX_SIZE( input_data->len ) );
-    ASSERT_ALLOC( output_data, output_size );
-
     PSA_ASSERT( psa_crypto_init( ) );
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT  );
@@ -848,6 +839,18 @@
 
     PSA_ASSERT( psa_import_key( &attributes, key_data->x, key_data->len,
                                 &key ) );
+    PSA_ASSERT( psa_get_key_attributes( key, &attributes ) );
+    key_bits = psa_get_key_bits( &attributes );
+
+    output_size = input_data->len + PSA_AEAD_TAG_LENGTH( key_type, key_bits,
+                                                         alg );
+    /* For all currently defined algorithms, PSA_AEAD_ENCRYPT_OUTPUT_SIZE
+     * should be exact. */
+    TEST_EQUAL( output_size,
+                PSA_AEAD_ENCRYPT_OUTPUT_SIZE( key_type, alg, input_data->len ) );
+    TEST_ASSERT( output_size <=
+                 PSA_AEAD_ENCRYPT_OUTPUT_MAX_SIZE( input_data->len ) );
+    ASSERT_ALLOC( output_data, output_size );
 
     test_driver_aead_hooks.forced_status = forced_status;
     status = psa_aead_encrypt( key, alg,
@@ -888,18 +891,15 @@
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_type_t key_type = key_type_arg;
     psa_algorithm_t alg = alg_arg;
+    size_t key_bits;
     psa_status_t forced_status = forced_status_arg;
     unsigned char *output_data = NULL;
     size_t output_size = 0;
     size_t output_length = 0;
-    size_t tag_length = PSA_AEAD_TAG_LENGTH( alg );
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_status_t status = PSA_ERROR_GENERIC_ERROR;
     test_driver_aead_hooks = test_driver_aead_hooks_init();
 
-    output_size = input_data->len - tag_length;
-    ASSERT_ALLOC( output_data, output_size );
-
     PSA_ASSERT( psa_crypto_init( ) );
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_DECRYPT  );
@@ -908,6 +908,12 @@
 
     PSA_ASSERT( psa_import_key( &attributes, key_data->x, key_data->len,
                                 &key ) );
+    PSA_ASSERT( psa_get_key_attributes( key, &attributes ) );
+    key_bits = psa_get_key_bits( &attributes );
+
+    output_size = input_data->len - PSA_AEAD_TAG_LENGTH( key_type, key_bits,
+                                                         alg );
+    ASSERT_ALLOC( output_data, output_size );
 
     test_driver_aead_hooks.forced_status = forced_status;
     status = psa_aead_decrypt( key, alg,
diff --git a/tests/suites/test_suite_psa_crypto_metadata.data b/tests/suites/test_suite_psa_crypto_metadata.data
index bd98a76..4e2f4d5 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.data
+++ b/tests/suites/test_suite_psa_crypto_metadata.data
@@ -134,17 +134,57 @@
 depends_on:PSA_WANT_ALG_XTS:MBEDTLS_CIPHER_C
 cipher_algorithm:PSA_ALG_XTS:0
 
-AEAD: CCM
-depends_on:PSA_WANT_ALG_CCM
-aead_algorithm:PSA_ALG_CCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16
+AEAD: CCM-AES-128
+depends_on:PSA_WANT_KEY_TYPE_AES:PSA_WANT_ALG_CCM
+aead_algorithm:PSA_ALG_CCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_AES:128
 
-AEAD: GCM
-depends_on:PSA_WANT_ALG_GCM
-aead_algorithm:PSA_ALG_GCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16
+AEAD: CCM-AES-192
+depends_on:PSA_WANT_KEY_TYPE_AES:PSA_WANT_ALG_CCM
+aead_algorithm:PSA_ALG_CCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_AES:192
+
+AEAD: CCM-AES-256
+depends_on:PSA_WANT_KEY_TYPE_AES:PSA_WANT_ALG_CCM
+aead_algorithm:PSA_ALG_CCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_AES:256
+
+AEAD: CCM-CAMELLIA-128
+depends_on:PSA_WANT_KEY_TYPE_CAMELLIA:PSA_WANT_ALG_CCM
+aead_algorithm:PSA_ALG_CCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_CAMELLIA:128
+
+AEAD: CCM-CAMELLIA-192
+depends_on:PSA_WANT_KEY_TYPE_CAMELLIA:PSA_WANT_ALG_CCM
+aead_algorithm:PSA_ALG_CCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_CAMELLIA:192
+
+AEAD: CCM-CAMELLIA-256
+depends_on:PSA_WANT_KEY_TYPE_CAMELLIA:PSA_WANT_ALG_CCM
+aead_algorithm:PSA_ALG_CCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_CAMELLIA:256
+
+AEAD: GCM-AES-128
+depends_on:PSA_WANT_KEY_TYPE_AES:PSA_WANT_ALG_GCM
+aead_algorithm:PSA_ALG_GCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_AES:128
+
+AEAD: GCM-AES-192
+depends_on:PSA_WANT_KEY_TYPE_AES:PSA_WANT_ALG_GCM
+aead_algorithm:PSA_ALG_GCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_AES:192
+
+AEAD: GCM-AES-256
+depends_on:PSA_WANT_KEY_TYPE_AES:PSA_WANT_ALG_GCM
+aead_algorithm:PSA_ALG_GCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_AES:256
+
+AEAD: GCM-CAMELLIA-128
+depends_on:PSA_WANT_KEY_TYPE_CAMELLIA:PSA_WANT_ALG_GCM
+aead_algorithm:PSA_ALG_GCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_CAMELLIA:128
+
+AEAD: GCM-CAMELLIA-192
+depends_on:PSA_WANT_KEY_TYPE_CAMELLIA:PSA_WANT_ALG_GCM
+aead_algorithm:PSA_ALG_GCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_CAMELLIA:192
+
+AEAD: GCM-CAMELLIA-256
+depends_on:PSA_WANT_KEY_TYPE_CAMELLIA:PSA_WANT_ALG_GCM
+aead_algorithm:PSA_ALG_GCM:ALG_IS_AEAD_ON_BLOCK_CIPHER:16:PSA_KEY_TYPE_CAMELLIA:256
 
 AEAD: ChaCha20_Poly1305
 depends_on:PSA_WANT_ALG_CHACHA20_POLY1305
-aead_algorithm:PSA_ALG_CHACHA20_POLY1305:0:16
+aead_algorithm:PSA_ALG_CHACHA20_POLY1305:0:16:PSA_KEY_TYPE_CHACHA20:256
 
 Asymmetric signature: RSA PKCS#1 v1.5 raw
 depends_on:PSA_WANT_ALG_RSA_PKCS1V15_SIGN
diff --git a/tests/suites/test_suite_psa_crypto_metadata.function b/tests/suites/test_suite_psa_crypto_metadata.function
index 8acbe44..8134f44 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.function
+++ b/tests/suites/test_suite_psa_crypto_metadata.function
@@ -169,6 +169,7 @@
 }
 
 void aead_algorithm_core( psa_algorithm_t alg, int classification_flags,
+                          psa_key_type_t key_type, size_t key_bits,
                           size_t tag_length )
 {
     /* Algorithm classification */
@@ -183,7 +184,7 @@
     algorithm_classification( alg, classification_flags );
 
     /* Tag length */
-    TEST_EQUAL( tag_length, PSA_AEAD_TAG_LENGTH( alg ) );
+    TEST_EQUAL( tag_length, PSA_AEAD_TAG_LENGTH( key_type, key_bits, alg ) );
 
 exit: ;
 }
@@ -367,19 +368,24 @@
 
 /* BEGIN_CASE */
 void aead_algorithm( int alg_arg, int classification_flags,
-                     int tag_length_arg )
+                     int tag_length_arg,
+                     int key_type_arg, int key_bits_arg )
 {
     psa_algorithm_t alg = alg_arg;
     size_t tag_length = tag_length_arg;
     size_t n;
+    psa_key_type_t key_type = key_type_arg;
+    size_t key_bits = key_bits_arg;
 
-    aead_algorithm_core( alg, classification_flags, tag_length );
+    aead_algorithm_core( alg, classification_flags,
+                         key_type, key_bits, tag_length );
 
     /* Truncated versions */
     for( n = 1; n <= tag_length; n++ )
     {
         psa_algorithm_t truncated_alg = PSA_ALG_AEAD_WITH_SHORTENED_TAG( alg, n );
-        aead_algorithm_core( truncated_alg, classification_flags, n );
+        aead_algorithm_core( truncated_alg, classification_flags,
+                             key_type, key_bits, n );
         TEST_EQUAL( PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG( truncated_alg ),
                     alg );
         /* Check that calling PSA_ALG_AEAD_WITH_SHORTENED_TAG twice gives
@@ -411,7 +417,8 @@
     for( n = 1; n <= tag_length; n++ )
     {
         psa_algorithm_t policy_alg = PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG( alg, n );
-        aead_algorithm_core( policy_alg, classification_flags | ALG_IS_WILDCARD, n );
+        aead_algorithm_core( policy_alg, classification_flags | ALG_IS_WILDCARD,
+                             key_type, key_bits, n );
         TEST_EQUAL( PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG( policy_alg ),
                     alg );
         /* Check that calling PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG twice