Systematically generate test cases for operation setup failure

The test suite test_suite_psa_crypto_op_fail now runs a large number
of automatically generated test cases which attempt to perform a
one-shot operation or to set up a multi-part operation with invalid
parameters. The following cases are fully covered (based on the
enumeration of valid algorithms and key types):
* An algorithm is not supported.
* The key type is not compatible with the algorithm (for operations
  that use a key).
* The algorithm is not compatible for the operation.

Some test functions allow the library to return PSA_ERROR_NOT_SUPPORTED
where the test code generator expects PSA_ERROR_INVALID_ARGUMENT or vice
versa. This may be refined in the future.

Some corner cases with algorithms combining a key agreement with a key
derivation are not handled properly. This will be fixed in follow-up
commits.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index e7819b9..f82a1e5 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -21,6 +21,7 @@
 # limitations under the License.
 
 import argparse
+import enum
 import os
 import posixpath
 import re
@@ -309,39 +310,126 @@
     """Generate test cases for operations that must fail."""
     #pylint: disable=too-few-public-methods
 
+    class Reason(enum.Enum):
+        NOT_SUPPORTED = 0
+        INVALID = 1
+        INCOMPATIBLE = 2
+
     def __init__(self, info: Information) -> None:
         self.constructors = info.constructors
+        key_type_expressions = self.constructors.generate_expressions(
+            sorted(self.constructors.key_types)
+        )
+        self.key_types = [crypto_knowledge.KeyType(kt_expr)
+                          for kt_expr in key_type_expressions]
 
-    @staticmethod
-    def hash_test_cases(alg: str) -> Iterator[test_case.TestCase]:
-        """Generate hash failure test cases for the specified algorithm."""
+    def make_test_case(
+            self,
+            alg: crypto_knowledge.Algorithm,
+            category: crypto_knowledge.AlgorithmCategory,
+            reason: 'Reason',
+            kt: Optional[crypto_knowledge.KeyType] = None,
+            not_deps: FrozenSet[str] = frozenset(),
+    ) -> test_case.TestCase:
+        """Construct a failure test case for a one-key or keyless operation."""
+        #pylint: disable=too-many-arguments,too-many-locals
         tc = test_case.TestCase()
-        is_hash = (alg.startswith('PSA_ALG_SHA') or
-                   alg.startswith('PSA_ALG_MD') or
-                   alg in frozenset(['PSA_ALG_RIPEMD160', 'PSA_ALG_ANY_HASH']))
-        if is_hash:
-            descr = 'not supported'
-            status = 'PSA_ERROR_NOT_SUPPORTED'
-            dependencies = ['!PSA_WANT_' + alg[4:]]
+        pretty_alg = re.sub(r'PSA_ALG_', r'', alg.expression)
+        pretty_reason = reason.name.lower()
+        if kt:
+            key_type = kt.expression
+            pretty_type = re.sub(r'PSA_KEY_TYPE_', r'', key_type)
         else:
-            descr = 'invalid'
-            status = 'PSA_ERROR_INVALID_ARGUMENT'
-            dependencies = automatic_dependencies(alg)
-        tc.set_description('PSA hash {}: {}'
-                           .format(descr, re.sub(r'PSA_ALG_', r'', alg)))
+            key_type = ''
+            pretty_type = ''
+        tc.set_description('PSA {} {}: {}{}'
+                           .format(category.name.lower(),
+                                   pretty_alg,
+                                   pretty_reason,
+                                   ' with ' + pretty_type if pretty_type else ''))
+        dependencies = automatic_dependencies(alg.base_expression, key_type)
+        for i, dep in enumerate(dependencies):
+            if dep in not_deps:
+                dependencies[i] = '!' + dep
         tc.set_dependencies(dependencies)
-        tc.set_function('hash_fail')
-        tc.set_arguments([alg, status])
-        yield tc
+        tc.set_function(category.name.lower() + '_fail')
+        arguments = []
+        if kt:
+            key_material = kt.key_material(kt.sizes_to_test()[0])
+            arguments += [key_type, test_case.hex_string(key_material)]
+        arguments.append(alg.expression)
+        error = ('NOT_SUPPORTED' if reason == self.Reason.NOT_SUPPORTED else
+                 'INVALID_ARGUMENT')
+        arguments.append('PSA_ERROR_' + error)
+        tc.set_arguments(arguments)
+        return tc
 
-    def test_cases_for_algorithm(self, alg: str) -> Iterator[test_case.TestCase]:
+    def no_key_test_cases(
+            self,
+            alg: crypto_knowledge.Algorithm,
+            category: crypto_knowledge.AlgorithmCategory,
+    ) -> Iterator[test_case.TestCase]:
+        """Generate failure test cases for keyless operations with the specified algorithm."""
+        if category == alg.category:
+            # Compatible operation, unsupported algorithm
+            for dep in automatic_dependencies(alg.base_expression):
+                yield self.make_test_case(alg, category,
+                                          self.Reason.NOT_SUPPORTED,
+                                          not_deps=frozenset([dep]))
+        else:
+            # Incompatible operation, supported algorithm
+            yield self.make_test_case(alg, category, self.Reason.INVALID)
+
+    def one_key_test_cases(
+            self,
+            alg: crypto_knowledge.Algorithm,
+            category: crypto_knowledge.AlgorithmCategory,
+    ) -> Iterator[test_case.TestCase]:
+        """Generate failure test cases for one-key operations with the specified algorithm."""
+        for kt in self.key_types:
+            key_is_compatible = kt.can_do(alg)
+            # To do: public key for a private key operation
+            if key_is_compatible and category == alg.category:
+                # Compatible key and operation, unsupported algorithm
+                for dep in automatic_dependencies(alg.base_expression):
+                    yield self.make_test_case(alg, category,
+                                              self.Reason.NOT_SUPPORTED,
+                                              kt=kt, not_deps=frozenset([dep]))
+            elif key_is_compatible:
+                # Compatible key, incompatible operation, supported algorithm
+                yield self.make_test_case(alg, category,
+                                          self.Reason.INVALID,
+                                          kt=kt)
+            elif category == alg.category:
+                # Incompatible key, compatible operation, supported algorithm
+                yield self.make_test_case(alg, category,
+                                          self.Reason.INCOMPATIBLE,
+                                          kt=kt)
+            else:
+                # Incompatible key and operation. Don't test cases where
+                # multiple things are wrong, to keep the number of test
+                # cases reasonable.
+                pass
+
+    def test_cases_for_algorithm(
+            self,
+            alg: crypto_knowledge.Algorithm,
+    ) -> Iterator[test_case.TestCase]:
         """Generate operation failure test cases for the specified algorithm."""
-        yield from self.hash_test_cases(alg)
+        for category in crypto_knowledge.AlgorithmCategory:
+            if category == crypto_knowledge.AlgorithmCategory.PAKE:
+                # PAKE operations are not implemented yet
+                pass
+            elif category.requires_key():
+                yield from self.one_key_test_cases(alg, category)
+            else:
+                yield from self.no_key_test_cases(alg, category)
 
     def all_test_cases(self) -> Iterator[test_case.TestCase]:
         """Generate all test cases for operations that must fail."""
         algorithms = sorted(self.constructors.algorithms)
-        for alg in self.constructors.generate_expressions(algorithms):
+        for expr in self.constructors.generate_expressions(algorithms):
+            alg = crypto_knowledge.Algorithm(expr)
             yield from self.test_cases_for_algorithm(alg)
 
 
diff --git a/tests/suites/test_suite_psa_crypto_op_fail.function b/tests/suites/test_suite_psa_crypto_op_fail.function
index 21dbafa..4640649 100644
--- a/tests/suites/test_suite_psa_crypto_op_fail.function
+++ b/tests/suites/test_suite_psa_crypto_op_fail.function
@@ -3,6 +3,37 @@
 #include "psa/crypto.h"
 #include "test/psa_crypto_helpers.h"
 
+static int test_equal_status( const char *test,
+                              int line_no, const char* filename,
+                              psa_status_t value1,
+                              psa_status_t value2 )
+{
+    if( ( value1 == PSA_ERROR_INVALID_ARGUMENT &&
+          value2 == PSA_ERROR_NOT_SUPPORTED ) ||
+        ( value1 == PSA_ERROR_NOT_SUPPORTED &&
+          value2 == PSA_ERROR_INVALID_ARGUMENT ) )
+    {
+        return( 1 );
+    }
+    return( mbedtls_test_equal( test, line_no, filename, value1, value2 ) );
+}
+
+/** Like #TEST_EQUAL, but expects #psa_status_t values and treats
+ * #PSA_ERROR_INVALID_ARGUMENT and #PSA_ERROR_NOT_SUPPORTED as
+ * interchangeable.
+ *
+ * This test suite currently allows NOT_SUPPORTED and INVALID_ARGUMENT
+ * to be interchangeable in places where the library's behavior does not
+ * match the strict expectations of the test case generator. In the long
+ * run, it would be better to clarify the expectations and reconcile the
+ * library and the test case generator.
+ */
+#define TEST_STATUS( expr1, expr2 )                                     \
+    do {                                                                \
+        if( ! test_equal_status( #expr1 " == " #expr2, __LINE__, __FILE__, \
+                                 expr1, expr2 ) )                       \
+            goto exit;                                                  \
+    } while( 0 )
 
 /* END_HEADER */
 
@@ -37,3 +68,301 @@
     PSA_DONE( );
 }
 /* END_CASE */
+
+/* BEGIN_CASE */
+void mac_fail( int key_type_arg, data_t *key_data,
+               int alg_arg, int expected_status_arg )
+{
+    psa_status_t expected_status = expected_status_arg;
+    psa_key_type_t key_type = key_type_arg;
+    psa_algorithm_t alg = alg_arg;
+    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    uint8_t input[1] = {'A'};
+    uint8_t output[PSA_MAC_MAX_SIZE] = {0};
+    size_t length = SIZE_MAX;
+
+    PSA_INIT( );
+
+    psa_set_key_type( &attributes, key_type );
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_SIGN_HASH |
+                             PSA_KEY_USAGE_VERIFY_HASH );
+    psa_set_key_algorithm( &attributes, alg );
+    PSA_ASSERT( psa_import_key( &attributes,
+                                key_data->x, key_data->len,
+                                &key_id ) );
+
+    TEST_STATUS( expected_status,
+                 psa_mac_sign_setup( &operation, key_id, alg ) );
+    TEST_STATUS( expected_status,
+                 psa_mac_verify_setup( &operation, key_id, alg ) );
+    TEST_STATUS( expected_status,
+                 psa_mac_compute( key_id, alg,
+                                  input, sizeof( input ),
+                                  output, sizeof( output ), &length ) );
+    TEST_STATUS( expected_status,
+                 psa_mac_verify( key_id, alg,
+                                 input, sizeof( input ),
+                                 output, sizeof( output ) ) );
+
+exit:
+    psa_mac_abort( &operation );
+    psa_destroy_key( key_id );
+    psa_reset_key_attributes( &attributes );
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void cipher_fail( int key_type_arg, data_t *key_data,
+                  int alg_arg, int expected_status_arg )
+{
+    psa_status_t expected_status = expected_status_arg;
+    psa_key_type_t key_type = key_type_arg;
+    psa_algorithm_t alg = alg_arg;
+    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    uint8_t input[1] = {'A'};
+    uint8_t output[64] = {0};
+    size_t length = SIZE_MAX;
+
+    PSA_INIT( );
+
+    psa_set_key_type( &attributes, key_type );
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_ENCRYPT |
+                             PSA_KEY_USAGE_DECRYPT );
+    psa_set_key_algorithm( &attributes, alg );
+    PSA_ASSERT( psa_import_key( &attributes,
+                                key_data->x, key_data->len,
+                                &key_id ) );
+
+    TEST_STATUS( expected_status,
+                 psa_cipher_encrypt_setup( &operation, key_id, alg ) );
+    TEST_STATUS( expected_status,
+                 psa_cipher_decrypt_setup( &operation, key_id, alg ) );
+    TEST_STATUS( expected_status,
+                 psa_cipher_encrypt( key_id, alg,
+                                     input, sizeof( input ),
+                                     output, sizeof( output ), &length ) );
+    TEST_STATUS( expected_status,
+                 psa_cipher_decrypt( key_id, alg,
+                                     input, sizeof( input ),
+                                     output, sizeof( output ), &length ) );
+
+exit:
+    psa_cipher_abort( &operation );
+    psa_destroy_key( key_id );
+    psa_reset_key_attributes( &attributes );
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void aead_fail( int key_type_arg, data_t *key_data,
+                int alg_arg, int expected_status_arg )
+{
+    psa_status_t expected_status = expected_status_arg;
+    psa_key_type_t key_type = key_type_arg;
+    psa_algorithm_t alg = alg_arg;
+    psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    uint8_t input[16] = "ABCDEFGHIJKLMNO";
+    uint8_t output[64] = {0};
+    size_t length = SIZE_MAX;
+
+    PSA_INIT( );
+
+    psa_set_key_type( &attributes, key_type );
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_ENCRYPT |
+                             PSA_KEY_USAGE_DECRYPT );
+    psa_set_key_algorithm( &attributes, alg );
+    PSA_ASSERT( psa_import_key( &attributes,
+                                key_data->x, key_data->len,
+                                &key_id ) );
+
+    TEST_STATUS( expected_status,
+                 psa_aead_encrypt_setup( &operation, key_id, alg ) );
+    TEST_STATUS( expected_status,
+                 psa_aead_decrypt_setup( &operation, key_id, alg ) );
+    TEST_STATUS( expected_status,
+                 psa_aead_encrypt( key_id, alg,
+                                   input, sizeof( input ),
+                                   NULL, 0, input, sizeof( input ),
+                                   output, sizeof( output ), &length ) );
+    TEST_STATUS( expected_status,
+                 psa_aead_decrypt( key_id, alg,
+                                   input, sizeof( input ),
+                                   NULL, 0, input, sizeof( input ),
+                                   output, sizeof( output ), &length ) );
+
+exit:
+    psa_aead_abort( &operation );
+    psa_destroy_key( key_id );
+    psa_reset_key_attributes( &attributes );
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void sign_fail( int key_type_arg, data_t *key_data,
+                int alg_arg, int expected_status_arg )
+{
+    psa_status_t expected_status = expected_status_arg;
+    psa_key_type_t key_type = key_type_arg;
+    psa_algorithm_t alg = alg_arg;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    uint8_t input[1] = {'A'};
+    uint8_t output[PSA_SIGNATURE_MAX_SIZE] = {0};
+    size_t length = SIZE_MAX;
+
+    PSA_INIT( );
+
+    psa_set_key_type( &attributes, key_type );
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_SIGN_HASH |
+                             PSA_KEY_USAGE_VERIFY_HASH );
+    psa_set_key_algorithm( &attributes, alg );
+    PSA_ASSERT( psa_import_key( &attributes,
+                                key_data->x, key_data->len,
+                                &key_id ) );
+
+    TEST_STATUS( expected_status,
+                 psa_sign_hash( key_id, alg,
+                                input, sizeof( input ),
+                                output, sizeof( output ), &length ) );
+    TEST_STATUS( expected_status,
+                 psa_verify_hash( key_id, alg,
+                                  input, sizeof( input ),
+                                  output, sizeof( output ) ) );
+
+exit:
+    psa_destroy_key( key_id );
+    psa_reset_key_attributes( &attributes );
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void asymmetric_encryption_fail( int key_type_arg, data_t *key_data,
+                                 int alg_arg, int expected_status_arg )
+{
+    psa_status_t expected_status = expected_status_arg;
+    psa_key_type_t key_type = key_type_arg;
+    psa_algorithm_t alg = alg_arg;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    uint8_t plaintext[PSA_ASYMMETRIC_DECRYPT_OUTPUT_MAX_SIZE] = {0};
+    uint8_t ciphertext[PSA_ASYMMETRIC_ENCRYPT_OUTPUT_MAX_SIZE] = {0};
+    size_t length = SIZE_MAX;
+
+    PSA_INIT( );
+
+    psa_set_key_type( &attributes, key_type );
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_ENCRYPT |
+                             PSA_KEY_USAGE_DECRYPT );
+    psa_set_key_algorithm( &attributes, alg );
+    PSA_ASSERT( psa_import_key( &attributes,
+                                key_data->x, key_data->len,
+                                &key_id ) );
+
+    TEST_STATUS( expected_status,
+                 psa_asymmetric_encrypt( key_id, alg,
+                                         plaintext, 1,
+                                         NULL, 0,
+                                         ciphertext, sizeof( ciphertext ),
+                                         &length ) );
+    TEST_STATUS( expected_status,
+                 psa_asymmetric_decrypt( key_id, alg,
+                                         ciphertext, sizeof( ciphertext ),
+                                         NULL, 0,
+                                         plaintext, sizeof( plaintext ),
+                                         &length ) );
+
+exit:
+    psa_destroy_key( key_id );
+    psa_reset_key_attributes( &attributes );
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void key_derivation_fail( int alg_arg, int expected_status_arg )
+{
+    psa_status_t expected_status = expected_status_arg;
+    psa_algorithm_t alg = alg_arg;
+    psa_key_derivation_operation_t operation = PSA_KEY_DERIVATION_OPERATION_INIT;
+
+    PSA_INIT( );
+
+    TEST_EQUAL( expected_status,
+                psa_key_derivation_setup( &operation, alg ) );
+
+exit:
+    psa_key_derivation_abort( &operation );
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void key_agreement_fail( int key_type_arg, data_t *key_data,
+                         int alg_arg, int expected_status_arg )
+{
+    psa_status_t expected_status = expected_status_arg;
+    psa_key_type_t key_type = key_type_arg;
+    psa_algorithm_t alg = alg_arg;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    uint8_t public_key[PSA_EXPORT_PUBLIC_KEY_MAX_SIZE] = {0};
+    size_t public_key_length = SIZE_MAX;
+    uint8_t output[PSA_SIGNATURE_MAX_SIZE] = {0};
+    size_t length = SIZE_MAX;
+    psa_key_derivation_operation_t operation = PSA_KEY_DERIVATION_OPERATION_INIT;
+
+    PSA_INIT( );
+
+    psa_set_key_type( &attributes, key_type );
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_DERIVE );
+    psa_set_key_algorithm( &attributes, alg );
+    PSA_ASSERT( psa_import_key( &attributes,
+                                key_data->x, key_data->len,
+                                &key_id ) );
+    if( PSA_KEY_TYPE_IS_KEY_PAIR( key_type ) ||
+        PSA_KEY_TYPE_IS_PUBLIC_KEY( key_type ) )
+    {
+        PSA_ASSERT( psa_export_public_key( key_id,
+                                           public_key, sizeof( public_key ),
+                                           &public_key_length ) );
+    }
+
+    TEST_STATUS( expected_status,
+                 psa_raw_key_agreement( alg, key_id,
+                                        public_key, public_key_length,
+                                        output, sizeof( output ), &length ) );
+
+#if defined(PSA_WANT_ALG_HKDF) && defined(PSA_WANT_ALG_SHA_256)
+    PSA_ASSERT( psa_key_derivation_setup( &operation,
+                                          PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ) );
+    TEST_STATUS( expected_status,
+                 psa_key_derivation_key_agreement(
+                     &operation,
+                     PSA_KEY_DERIVATION_INPUT_SECRET,
+                     key_id,
+                     public_key, public_key_length ) );
+#endif
+
+exit:
+    psa_key_derivation_abort( &operation );
+    psa_destroy_key( key_id );
+    psa_reset_key_attributes( &attributes );
+    PSA_DONE( );
+}
+/* END_CASE */