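Enable finite-field DH in generate_psa_tests.py

Mbed TLS now supports finite-field DH key types, so stop discarding
PSA_KEY_TYPE_DH_KEY_PAIR and PSA_KEY_TYPE_DH_PUBLIC_KEY when collecting
macros. Like the ECC key types, the DH key types are parametrized by a
family, so skip them in the generic per-type loops and instead generate
the "not supported" and key-generation test cases once per DH group.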

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index f5d83c6..738136c 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -152,10 +152,8 @@
     def remove_unwanted_macros(
             constructors: macro_collector.PSAMacroEnumerator
     ) -> None:
-        # Mbed TLS doesn't support finite-field DH yet and will not support
-        # finite-field DSA. Don't attempt to generate any related test case.
-        constructors.key_types.discard('PSA_KEY_TYPE_DH_KEY_PAIR')
-        constructors.key_types.discard('PSA_KEY_TYPE_DH_PUBLIC_KEY')
+        # Mbed TLS does not support finite-field DSA.
+        # Don't attempt to generate any related test case.
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_KEY_PAIR')
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_PUBLIC_KEY')
 
@@ -261,12 +259,16 @@
 
     ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                      'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')
 
     def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
         """Generate test cases that exercise the creation of keys of unsupported types."""
         for key_type in sorted(self.constructors.key_types):
             if key_type in self.ECC_KEY_TYPES:
                 continue
+            if key_type in self.DH_KEY_TYPES:
+                continue
             kt = crypto_knowledge.KeyType(key_type)
             yield from self.test_cases_for_key_type_not_supported(kt)
         for curve_family in sorted(self.constructors.ecc_curves):
@@ -276,6 +278,13 @@
                     kt, param_descr='type')
                 yield from self.test_cases_for_key_type_not_supported(
                     kt, 0, param_descr='curve')
+        for dh_family in sorted(self.constructors.dh_groups):
+            for constr in self.DH_KEY_TYPES:
+                kt = crypto_knowledge.KeyType(constr, [dh_family])
+                yield from self.test_cases_for_key_type_not_supported(
+                    kt, param_descr='type')
+                yield from self.test_cases_for_key_type_not_supported(
+                    kt, 0, param_descr='group')
 
 def test_case_for_key_generation(
         key_type: str, bits: int,
@@ -304,6 +313,8 @@
 
     ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                      'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')
 
     @staticmethod
     def test_cases_for_key_type_key_generation(
@@ -341,12 +352,18 @@
         for key_type in sorted(self.constructors.key_types):
             if key_type in self.ECC_KEY_TYPES:
                 continue
+            if key_type in self.DH_KEY_TYPES:
+                continue
             kt = crypto_knowledge.KeyType(key_type)
             yield from self.test_cases_for_key_type_key_generation(kt)
         for curve_family in sorted(self.constructors.ecc_curves):
             for constr in self.ECC_KEY_TYPES:
                 kt = crypto_knowledge.KeyType(constr, [curve_family])
                 yield from self.test_cases_for_key_type_key_generation(kt)
+        for dh_family in sorted(self.constructors.dh_groups):
+            for constr in self.DH_KEY_TYPES:
+                kt = crypto_knowledge.KeyType(constr, [dh_family])
+                yield from self.test_cases_for_key_type_key_generation(kt)
 
 class OpFail:
     """Generate test cases for operations that must fail."""