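Break up the god class TestGenerator

Use separate classes for gathering information (Information), for each
kind of test generation (currently just one: NotSupported, for
"not supported" test cases), and for writing output files
(TestGenerator). The helper test_case_for_key_type_not_supported() is
moved unchanged.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>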
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index aae92d6..6baf53e 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -69,42 +69,14 @@
for dep in dependencies):
dependencies.append('DEPENDENCY_NOT_IMPLEMENTED_YET')
-def test_case_for_key_type_not_supported(
- verb: str, key_type: str, bits: int,
- dependencies: List[str],
- *args: str,
- param_descr: str = ''
-) -> test_case.TestCase:
- """Return one test case exercising a key creation method
- for an unsupported key type or size.
- """
- hack_dependencies_not_implemented(dependencies)
- tc = test_case.TestCase()
- short_key_type = re.sub(r'PSA_(KEY_TYPE|ECC_FAMILY)_', r'', key_type)
- adverb = 'not' if dependencies else 'never'
- if param_descr:
- adverb = param_descr + ' ' + adverb
- tc.set_description('PSA {} {} {}-bit {} supported'
- .format(verb, short_key_type, bits, adverb))
- tc.set_dependencies(dependencies)
- tc.set_function(verb + '_not_supported')
- tc.set_arguments([key_type] + list(args))
- return tc
-class TestGenerator:
- """Gather information and generate test data."""
+class Information:
+ """Gather information about PSA constructors."""
- def __init__(self, options):
- self.test_suite_directory = self.get_option(options, 'directory',
- 'tests/suites')
+ def __init__(self) -> None:
self.constructors = self.read_psa_interface()
@staticmethod
- def get_option(options, name: str, default: T) -> T:
- value = getattr(options, name, None)
- return default if value is None else value
-
- @staticmethod
def remove_unwanted_macros(
constructors: macro_collector.PSAMacroCollector
) -> None:
@@ -126,14 +98,34 @@
self.remove_unwanted_macros(constructors)
return constructors
- def write_test_data_file(self, basename: str,
- test_cases: Iterable[test_case.TestCase]) -> None:
- """Write the test cases to a .data file.
- The output file is ``basename + '.data'`` in the test suite directory.
- """
- filename = os.path.join(self.test_suite_directory, basename + '.data')
- test_case.write_data_file(filename, test_cases)
+def test_case_for_key_type_not_supported(
+ verb: str, key_type: str, bits: int,
+ dependencies: List[str],
+ *args: str,
+ param_descr: str = ''
+) -> test_case.TestCase:
+ """Return one test case exercising a key creation method
+ for an unsupported key type or size.
+ """
+ hack_dependencies_not_implemented(dependencies)
+ tc = test_case.TestCase()
+ short_key_type = re.sub(r'PSA_(KEY_TYPE|ECC_FAMILY)_', r'', key_type)
+ adverb = 'not' if dependencies else 'never'
+ if param_descr:
+ adverb = param_descr + ' ' + adverb
+ tc.set_description('PSA {} {} {}-bit {} supported'
+ .format(verb, short_key_type, bits, adverb))
+ tc.set_dependencies(dependencies)
+ tc.set_function(verb + '_not_supported')
+ tc.set_arguments([key_type] + list(args))
+ return tc
+
+class NotSupported:
+ """Generate test cases for when something is not supported."""
+
+ def __init__(self, info: Information) -> None:
+ self.constructors = info.constructors
ALWAYS_SUPPORTED = frozenset([
'PSA_KEY_TYPE_DERIVE',
@@ -187,7 +179,7 @@
# To be added: derive
return test_cases
- def generate_not_supported(self) -> None:
+ def generate_not_supported(self) -> List[test_case.TestCase]:
"""Generate test cases that exercise the creation of keys of unsupported types."""
test_cases = []
for key_type in sorted(self.constructors.key_types):
@@ -202,13 +194,37 @@
kt, param_descr='type')
test_cases += self.test_cases_for_key_type_not_supported(
kt, 0, param_descr='curve')
+ return test_cases
+
+
+class TestGenerator:
+ """Generate test data."""
+
+ def __init__(self, options) -> None:
+ self.test_suite_directory = self.get_option(options, 'directory',
+ 'tests/suites')
+ self.info = Information()
+
+ @staticmethod
+ def get_option(options, name: str, default: T) -> T:
+ value = getattr(options, name, None)
+ return default if value is None else value
+
+ def write_test_data_file(self, basename: str,
+ test_cases: Iterable[test_case.TestCase]) -> None:
+ """Write the test cases to a .data file.
+
+ The output file is ``basename + '.data'`` in the test suite directory.
+ """
+ filename = os.path.join(self.test_suite_directory, basename + '.data')
+ test_case.write_data_file(filename, test_cases)
+
+ def generate_all(self) -> None:
+ test_cases = NotSupported(self.info).generate_not_supported()
self.write_test_data_file(
'test_suite_psa_crypto_not_supported.generated',
test_cases)
- def generate_all(self):
- self.generate_not_supported()
-
def main(args):
"""Command line entry point."""
parser = argparse.ArgumentParser(description=__doc__)