Return the same error in multipart and single shot AEAD

psa_aead_encrypt_setup() and psa_aead_decrypt_setup() were returning
PSA_ERROR_INVALID_ARGUMENT, while the same failed checks were producing
PSA_ERROR_NOT_SUPPORTED if they happened in psa_aead_encrypt() or
psa_aead_decrypt().

The PSA Crypto API 1.1 spec will specify PSA_ERROR_INVALID_ARGUMENT
in the case that the supplied algorithm is not an AEAD one.

Also move these shared checks to a helper function, to reduce code
duplication and ensure that the functions remain in sync.

Signed-off-by: Bence Szépkúti <bence.szepkuti@arm.com>
diff --git a/ChangeLog.d/psa_aead_singleshot_error.txt b/ChangeLog.d/psa_aead_singleshot_error.txt
new file mode 100644
index 0000000..7243874
--- /dev/null
+++ b/ChangeLog.d/psa_aead_singleshot_error.txt
@@ -0,0 +1,4 @@
+Changes
+   * Return PSA_ERROR_INVALID_ARGUMENT if the algorithm passed to singleshot
+     AEAD functions is not an AEAD algorithm. This aligns them with the
+     multipart functions, and the PSA Crypto API 1.1 spec.
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 829ed45..dbff133 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3719,6 +3719,13 @@
     return( PSA_ERROR_INVALID_ARGUMENT );
 }
 
+static psa_status_t psa_aead_initial_checks( psa_algorithm_t alg ) {
+    if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
+        return( PSA_ERROR_INVALID_ARGUMENT );
+
+    return( PSA_SUCCESS );
+}
+
 psa_status_t psa_aead_encrypt( mbedtls_svc_key_id_t key,
                                psa_algorithm_t alg,
                                const uint8_t *nonce,
@@ -3736,8 +3743,9 @@
 
     *ciphertext_length = 0;
 
-    if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
-        return( PSA_ERROR_NOT_SUPPORTED );
+    status = psa_aead_initial_checks( alg );
+    if( status != PSA_SUCCESS )
+        return( status );
 
     status = psa_get_and_lock_key_slot_with_policy(
                  key, &slot, PSA_KEY_USAGE_ENCRYPT, alg );
@@ -3786,8 +3794,9 @@
 
     *plaintext_length = 0;
 
-    if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
-        return( PSA_ERROR_NOT_SUPPORTED );
+    status = psa_aead_initial_checks( alg );
+    if( status != PSA_SUCCESS )
+        return( status );
 
     status = psa_get_and_lock_key_slot_with_policy(
                  key, &slot, PSA_KEY_USAGE_DECRYPT, alg );
@@ -3830,11 +3839,9 @@
     psa_key_slot_t *slot = NULL;
     psa_key_usage_t key_usage = 0;
 
-    if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
-    {
-        status = PSA_ERROR_INVALID_ARGUMENT;
+    status = psa_aead_initial_checks( alg );
+    if( status != PSA_SUCCESS )
         goto exit;
-    }
 
     if( operation->id != 0 )
     {
diff --git a/tests/scripts/test_psa_compliance.py b/tests/scripts/test_psa_compliance.py
index 942fd79..da5229b 100755
--- a/tests/scripts/test_psa_compliance.py
+++ b/tests/scripts/test_psa_compliance.py
@@ -47,7 +47,7 @@
 #
 # Web URL: https://github.com/bensze01/psa-arch-tests/tree/fixes-for-mbedtls-3
 PSA_ARCH_TESTS_REPO = 'https://github.com/bensze01/psa-arch-tests.git'
-PSA_ARCH_TESTS_REF = 'fix-multipart-aead'
+PSA_ARCH_TESTS_REF = 'fix-pr-5272'
 
 #pylint: disable=too-many-branches,too-many-statements
 def main():
diff --git a/tests/suites/test_suite_psa_crypto.data b/tests/suites/test_suite_psa_crypto.data
index c45f9f0..5aade05 100644
--- a/tests/suites/test_suite_psa_crypto.data
+++ b/tests/suites/test_suite_psa_crypto.data
@@ -842,7 +842,7 @@
 
 PSA key policy: AEAD, min-length policy used as algorithm
 depends_on:PSA_WANT_ALG_CCM:PSA_WANT_KEY_TYPE_AES
-aead_key_policy:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(PSA_ALG_CCM, 8):PSA_KEY_TYPE_AES:"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa":13:8:PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(PSA_ALG_CCM, 8):PSA_ERROR_NOT_SUPPORTED
+aead_key_policy:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(PSA_ALG_CCM, 8):PSA_KEY_TYPE_AES:"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa":13:8:PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(PSA_ALG_CCM, 8):PSA_ERROR_INVALID_ARGUMENT
 
 PSA key policy: AEAD, tag length > exact-length policy
 depends_on:PSA_WANT_ALG_CCM:PSA_WANT_KEY_TYPE_AES
@@ -2829,11 +2829,11 @@
 
 PSA AEAD encrypt/decrypt: invalid algorithm (CTR)
 depends_on:MBEDTLS_AES_C:MBEDTLS_GCM_C
-aead_encrypt_decrypt:PSA_KEY_TYPE_AES:"D7828D13B2B0BDC325A76236DF93CC6B":PSA_ALG_CTR:"000102030405060708090A0B0C0D0E0F":"":"":PSA_ERROR_NOT_SUPPORTED
+aead_encrypt_decrypt:PSA_KEY_TYPE_AES:"D7828D13B2B0BDC325A76236DF93CC6B":PSA_ALG_CTR:"000102030405060708090A0B0C0D0E0F":"":"":PSA_ERROR_INVALID_ARGUMENT
 
 PSA AEAD encrypt/decrypt: invalid algorithm (ChaCha20)
 depends_on:MBEDTLS_CHACHA20_C
-aead_encrypt_decrypt:PSA_KEY_TYPE_CHACHA20:"808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f":PSA_ALG_STREAM_CIPHER:"":"":"":PSA_ERROR_NOT_SUPPORTED
+aead_encrypt_decrypt:PSA_KEY_TYPE_CHACHA20:"808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f":PSA_ALG_STREAM_CIPHER:"":"":"":PSA_ERROR_INVALID_ARGUMENT
 
 PSA Multipart AEAD encrypt: AES - CCM, 23 bytes (lengths set)
 depends_on:PSA_WANT_ALG_CCM:PSA_WANT_KEY_TYPE_AES