Enable finite-field DH in generate_psa_tests.py
Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index f5d83c6..738136c 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -152,10 +152,8 @@
def remove_unwanted_macros(
constructors: macro_collector.PSAMacroEnumerator
) -> None:
- # Mbed TLS doesn't support finite-field DH yet and will not support
- # finite-field DSA. Don't attempt to generate any related test case.
- constructors.key_types.discard('PSA_KEY_TYPE_DH_KEY_PAIR')
- constructors.key_types.discard('PSA_KEY_TYPE_DH_PUBLIC_KEY')
+ # Mbed TLS does not support finite-field DSA and does not plan to,
+ # so don't attempt to generate any related test case.
constructors.key_types.discard('PSA_KEY_TYPE_DSA_KEY_PAIR')
constructors.key_types.discard('PSA_KEY_TYPE_DSA_PUBLIC_KEY')
@@ -261,12 +259,16 @@
ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+ DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+ 'PSA_KEY_TYPE_DH_PUBLIC_KEY')
def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
"""Generate test cases that exercise the creation of keys of unsupported types."""
for key_type in sorted(self.constructors.key_types):
if key_type in self.ECC_KEY_TYPES:
continue
+ if key_type in self.DH_KEY_TYPES:
+ continue
kt = crypto_knowledge.KeyType(key_type)
yield from self.test_cases_for_key_type_not_supported(kt)
for curve_family in sorted(self.constructors.ecc_curves):
@@ -276,6 +278,13 @@
kt, param_descr='type')
yield from self.test_cases_for_key_type_not_supported(
kt, 0, param_descr='curve')
+ for dh_family in sorted(self.constructors.dh_groups):
+ for constr in self.DH_KEY_TYPES:
+ kt = crypto_knowledge.KeyType(constr, [dh_family])
+ yield from self.test_cases_for_key_type_not_supported(
+ kt, param_descr='type')
+ yield from self.test_cases_for_key_type_not_supported(
+ kt, 0, param_descr='group')
def test_case_for_key_generation(
key_type: str, bits: int,
@@ -304,6 +313,8 @@
ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+ DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
+ 'PSA_KEY_TYPE_DH_PUBLIC_KEY')
@staticmethod
def test_cases_for_key_type_key_generation(
@@ -341,12 +352,18 @@
for key_type in sorted(self.constructors.key_types):
if key_type in self.ECC_KEY_TYPES:
continue
+ if key_type in self.DH_KEY_TYPES:
+ continue
kt = crypto_knowledge.KeyType(key_type)
yield from self.test_cases_for_key_type_key_generation(kt)
for curve_family in sorted(self.constructors.ecc_curves):
for constr in self.ECC_KEY_TYPES:
kt = crypto_knowledge.KeyType(constr, [curve_family])
yield from self.test_cases_for_key_type_key_generation(kt)
+ for dh_family in sorted(self.constructors.dh_groups):
+ for constr in self.DH_KEY_TYPES:
+ kt = crypto_knowledge.KeyType(constr, [dh_family])
+ yield from self.test_cases_for_key_type_key_generation(kt)
class OpFail:
"""Generate test cases for operations that must fail."""