Enable DH in generate_psa_tests.py

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py
index 819d92a..eab6f56 100644
--- a/scripts/mbedtls_dev/crypto_knowledge.py
+++ b/scripts/mbedtls_dev/crypto_knowledge.py
@@ -138,6 +138,9 @@
         """Whether the key type is for public keys."""
         return self.name.endswith('_PUBLIC_KEY')
 
+    DH_KEY_SIZES = {
+        'PSA_DH_FAMILY_RFC7919': (2048, 3072, 4096, 6144, 8192),
+    } # type: Dict[str, Tuple[int, ...]]
     ECC_KEY_SIZES = {
         'PSA_ECC_FAMILY_SECP_K1': (192, 224, 256),
         'PSA_ECC_FAMILY_SECP_R1': (225, 256, 384, 521),
@@ -175,6 +178,9 @@
         if self.private_type == 'PSA_KEY_TYPE_ECC_KEY_PAIR':
             assert self.params is not None
             return self.ECC_KEY_SIZES[self.params[0]]
+        if self.private_type == 'PSA_KEY_TYPE_DH_KEY_PAIR':
+            assert self.params is not None
+            return self.DH_KEY_SIZES[self.params[0]]
         return self.KEY_TYPE_SIZES[self.private_type]
 
     # "48657265006973206b6579a064617461"
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index f5d83c6..738136c 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -152,10 +152,8 @@
     def remove_unwanted_macros(
             constructors: macro_collector.PSAMacroEnumerator
     ) -> None:
-        # Mbed TLS doesn't support finite-field DH yet and will not support
-        # finite-field DSA. Don't attempt to generate any related test case.
-        constructors.key_types.discard('PSA_KEY_TYPE_DH_KEY_PAIR')
-        constructors.key_types.discard('PSA_KEY_TYPE_DH_PUBLIC_KEY')
+        # Mbed TLS does not support finite-field DSA.
+        # Don't attempt to generate any related test case.
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_KEY_PAIR')
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_PUBLIC_KEY')
 
@@ -261,12 +259,16 @@
 
     ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                      'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')
 
     def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
         """Generate test cases that exercise the creation of keys of unsupported types."""
         for key_type in sorted(self.constructors.key_types):
             if key_type in self.ECC_KEY_TYPES:
                 continue
+            if key_type in self.DH_KEY_TYPES:
+                continue
             kt = crypto_knowledge.KeyType(key_type)
             yield from self.test_cases_for_key_type_not_supported(kt)
         for curve_family in sorted(self.constructors.ecc_curves):
@@ -276,6 +278,13 @@
                     kt, param_descr='type')
                 yield from self.test_cases_for_key_type_not_supported(
                     kt, 0, param_descr='curve')
+        for dh_family in sorted(self.constructors.dh_groups):
+            for constr in self.DH_KEY_TYPES:
+                kt = crypto_knowledge.KeyType(constr, [dh_family])
+                yield from self.test_cases_for_key_type_not_supported(
+                    kt, param_descr='type')
+                yield from self.test_cases_for_key_type_not_supported(
+                    kt, 0, param_descr='group')
 
 def test_case_for_key_generation(
         key_type: str, bits: int,
@@ -304,6 +313,8 @@
 
     ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                      'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')
 
     @staticmethod
     def test_cases_for_key_type_key_generation(
@@ -341,12 +352,18 @@
         for key_type in sorted(self.constructors.key_types):
             if key_type in self.ECC_KEY_TYPES:
                 continue
+            if key_type in self.DH_KEY_TYPES:
+                continue
             kt = crypto_knowledge.KeyType(key_type)
             yield from self.test_cases_for_key_type_key_generation(kt)
         for curve_family in sorted(self.constructors.ecc_curves):
             for constr in self.ECC_KEY_TYPES:
                 kt = crypto_knowledge.KeyType(constr, [curve_family])
                 yield from self.test_cases_for_key_type_key_generation(kt)
+        for dh_family in sorted(self.constructors.dh_groups):
+            for constr in self.DH_KEY_TYPES:
+                kt = crypto_knowledge.KeyType(constr, [dh_family])
+                yield from self.test_cases_for_key_type_key_generation(kt)
 
 class OpFail:
     """Generate test cases for operations that must fail."""