Merge pull request #7909 from mpg/dh-generate-psa-tests

Enable DH in generate_psa_tests.py
diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py
index 819d92a..45d253b 100644
--- a/scripts/mbedtls_dev/crypto_knowledge.py
+++ b/scripts/mbedtls_dev/crypto_knowledge.py
@@ -34,7 +34,7 @@
     unambiguous, but ad hoc way.
     """
     short = original
-    short = re.sub(r'\bPSA_(?:ALG|ECC_FAMILY|KEY_[A-Z]+)_', r'', short)
+    short = re.sub(r'\bPSA_(?:ALG|DH_FAMILY|ECC_FAMILY|KEY_[A-Z]+)_', r'', short)
     short = re.sub(r' +', r'', short)
     if level >= 1:
         short = re.sub(r'PUBLIC_KEY\b', r'PUB', short)
@@ -138,6 +138,9 @@
         """Whether the key type is for public keys."""
         return self.name.endswith('_PUBLIC_KEY')
 
+    DH_KEY_SIZES = {
+        'PSA_DH_FAMILY_RFC7919': (2048, 3072, 4096, 6144, 8192),
+    } # type: Dict[str, Tuple[int, ...]]
     ECC_KEY_SIZES = {
         'PSA_ECC_FAMILY_SECP_K1': (192, 224, 256),
         'PSA_ECC_FAMILY_SECP_R1': (225, 256, 384, 521),
@@ -175,6 +178,9 @@
         if self.private_type == 'PSA_KEY_TYPE_ECC_KEY_PAIR':
             assert self.params is not None
             return self.ECC_KEY_SIZES[self.params[0]]
+        if self.private_type == 'PSA_KEY_TYPE_DH_KEY_PAIR':
+            assert self.params is not None
+            return self.DH_KEY_SIZES[self.params[0]]
         return self.KEY_TYPE_SIZES[self.private_type]
 
     # "48657265006973206b6579a064617461"
@@ -261,6 +267,8 @@
             if alg.head in {'PURE_EDDSA', 'EDDSA_PREHASH'} and \
                eccc == EllipticCurveCategory.TWISTED_EDWARDS:
                 return True
+        if self.head == 'DH' and alg.head == 'FFDH':
+            return True
         return False
 
 
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index f5d83c6..cad7884 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -111,7 +111,7 @@
         _implemented_dependencies = \
             read_implemented_dependencies('include/psa/crypto_config.h')
     if not all((dep.lstrip('!') in _implemented_dependencies or
-                'PSA_WANT' not in dep)
+                not dep.lstrip('!').startswith('PSA_WANT'))
                for dep in dependencies):
         dependencies.append('DEPENDENCY_NOT_IMPLEMENTED_YET')
 
@@ -121,7 +121,14 @@
     symbols according to the required usage.
     """
     ret_list = list()
-    if dep.endswith('KEY_PAIR'):
+    # Note: this LEGACY replacement for DH is temporary; it will be
+    # aligned with the ECC one in #7773.
+    if dep.endswith('DH_KEY_PAIR'):
+        legacy = dep
+        legacy = re.sub(r'KEY_PAIR\Z', r'KEY_PAIR_LEGACY', legacy)
+        legacy = re.sub(r'PSA_WANT', r'MBEDTLS_PSA_WANT', legacy)
+        ret_list.append(legacy)
+    elif dep.endswith('KEY_PAIR'):
         if usage == "BASIC":
             # BASIC automatically includes IMPORT and EXPORT for test purposes (see
             # config_psa.h).
@@ -152,10 +159,8 @@
     def remove_unwanted_macros(
             constructors: macro_collector.PSAMacroEnumerator
     ) -> None:
-        # Mbed TLS doesn't support finite-field DH yet and will not support
-        # finite-field DSA. Don't attempt to generate any related test case.
-        constructors.key_types.discard('PSA_KEY_TYPE_DH_KEY_PAIR')
-        constructors.key_types.discard('PSA_KEY_TYPE_DH_PUBLIC_KEY')
+        # Mbed TLS does not support finite-field DSA.
+        # Don't attempt to generate any related test case.
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_KEY_PAIR')
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_PUBLIC_KEY')
 
@@ -261,12 +266,16 @@
 
     ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                      'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')
 
     def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
         """Generate test cases that exercise the creation of keys of unsupported types."""
         for key_type in sorted(self.constructors.key_types):
             if key_type in self.ECC_KEY_TYPES:
                 continue
+            if key_type in self.DH_KEY_TYPES:
+                continue
             kt = crypto_knowledge.KeyType(key_type)
             yield from self.test_cases_for_key_type_not_supported(kt)
         for curve_family in sorted(self.constructors.ecc_curves):
@@ -276,6 +285,13 @@
                     kt, param_descr='type')
                 yield from self.test_cases_for_key_type_not_supported(
                     kt, 0, param_descr='curve')
+        for dh_family in sorted(self.constructors.dh_groups):
+            for constr in self.DH_KEY_TYPES:
+                kt = crypto_knowledge.KeyType(constr, [dh_family])
+                yield from self.test_cases_for_key_type_not_supported(
+                    kt, param_descr='type')
+                yield from self.test_cases_for_key_type_not_supported(
+                    kt, 0, param_descr='group')
 
 def test_case_for_key_generation(
         key_type: str, bits: int,
@@ -304,6 +320,8 @@
 
     ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                      'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')
 
     @staticmethod
     def test_cases_for_key_type_key_generation(
@@ -341,12 +359,18 @@
         for key_type in sorted(self.constructors.key_types):
             if key_type in self.ECC_KEY_TYPES:
                 continue
+            if key_type in self.DH_KEY_TYPES:
+                continue
             kt = crypto_knowledge.KeyType(key_type)
             yield from self.test_cases_for_key_type_key_generation(kt)
         for curve_family in sorted(self.constructors.ecc_curves):
             for constr in self.ECC_KEY_TYPES:
                 kt = crypto_knowledge.KeyType(constr, [curve_family])
                 yield from self.test_cases_for_key_type_key_generation(kt)
+        for dh_family in sorted(self.constructors.dh_groups):
+            for constr in self.DH_KEY_TYPES:
+                kt = crypto_knowledge.KeyType(constr, [dh_family])
+                yield from self.test_cases_for_key_type_key_generation(kt)
 
 class OpFail:
     """Generate test cases for operations that must fail."""