Split out PSAMacroEnumerator from the test data collection code

Split out the code that enumerates constructors of PSA Crypto types (the
new class PSAMacroEnumerator) from the code used to populate the lists of
constructors for the specific purpose of testing psa_constant_names (the
class formerly named Inputs, now renamed to InputsForTest).

This commit adds some documentation but otherwise strives to minimize
code changes.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 9e8d7f8..9795e05 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -33,6 +33,110 @@
 import scripts_path # pylint: disable=unused-import
 from mbedtls_dev import c_build_helper
 
+class PSAMacroEnumerator:
+    """Information about constructors of various PSA Crypto types.
+
+    This includes macro names as well as information about their arguments
+    when applicable.
+
+    This class only provides ways to enumerate expressions that evaluate to
+    values of the covered types. Derived classes are expected to populate
+    the sets of known constructors of each kind and `self.argspecs`, as well
+    as `self.arguments_for` for arguments that are not of a kind that is
+    enumerated here; then call `gather_arguments()` before enumerating.
+    """
+
+    def __init__(self):
+        """Set up an empty set of known constructor macros."""
+        self.statuses = set()
+        self.algorithms = set()
+        self.ecc_curves = set()
+        self.dh_groups = set()
+        self.key_types = set()
+        self.key_usage_flags = set()
+        self.hash_algorithms = set()
+        self.mac_algorithms = set()
+        self.ka_algorithms = set()
+        self.kdf_algorithms = set()
+        self.aead_algorithms = set()
+        # macro name -> list of argument names
+        self.argspecs = {}
+        # argument name -> list of values
+        self.arguments_for = {
+            'mac_length': [],
+            'min_mac_length': [],
+            'tag_length': [],
+            'min_tag_length': [],
+        }
+
+    def gather_arguments(self):
+        """Populate `self.arguments_for` from the known constructor sets.
+
+        Call this after parsing all the inputs and before enumerating.
+        """
+        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
+        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
+        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
+        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
+        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
+        self.arguments_for['curve'] = sorted(self.ecc_curves)
+        self.arguments_for['group'] = sorted(self.dh_groups)
+
+    @staticmethod
+    def _format_arguments(name, arguments):
+        """Format a macro call: ``name(arg1, arg2, ...)``."""
+        return name + '(' + ', '.join(arguments) + ')'
+
+    _argument_split_re = re.compile(r' *, *')
+    @classmethod
+    def _argument_split(cls, arguments):
+        """Split ``'a, b,c'`` into ``['a', 'b', 'c']``."""
+        return re.split(cls._argument_split_re, arguments)
+
+    def distribute_arguments(self, name):
+        """Generate macro calls with each tested argument set.
+
+        If name is not in argspecs, just yield "name"; if its argspec is
+        empty, yield "name()". Otherwise yield "name(arg1, ..., argN)" calls
+        where each argument takes each of its values at least once. Errors
+        are re-raised as an Exception that names the offending macro.
+        """
+        try:
+            if name not in self.argspecs:
+                yield name
+                return
+            argspec = self.argspecs[name]
+            if argspec == []:
+                yield name + '()'
+                return
+            argument_lists = [self.arguments_for[arg] for arg in argspec]
+            arguments = [values[0] for values in argument_lists]
+            yield self._format_arguments(name, arguments)
+            # Vary one argument at a time, in place (so no enumerate()).
+            # pylint: disable=consider-using-enumerate
+            for i in range(len(arguments)):
+                for value in argument_lists[i][1:]:
+                    arguments[i] = value
+                    yield self._format_arguments(name, arguments)
+                arguments[i] = argument_lists[i][0]  # its own first value
+        except Exception as e:  # not BaseException: let GeneratorExit through
+            raise Exception('distribute_arguments({})'.format(name)) from e
+
+    def generate_expressions(self, names):
+        """Generate expressions covering values constructed from `names`.
+
+        `names` can be any iterable collection of macro names.
+
+        For example:
+        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
+          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
+          every known hash algorithm ``h``.
+        * ``macros.generate_expressions(macros.key_types)`` generates all
+          key types.
+        """
+        return itertools.chain.from_iterable(
+            map(self.distribute_arguments, names))
+
 class ReadFileLineException(Exception):
     def __init__(self, filename, line_number):
         message = 'in {} at {}'.format(filename, line_number)
@@ -78,7 +182,7 @@
             raise ReadFileLineException(self.filename, self.line_number) \
                 from exc_value
 
-class Inputs:
+class InputsForTest(PSAMacroEnumerator):
     # pylint: disable=too-many-instance-attributes
     """Accumulate information about macros to test.
 
@@ -87,27 +191,29 @@
     """
 
     def __init__(self):
+        super().__init__()
         self.all_declared = set()
         # Sets of names per type
-        self.statuses = set(['PSA_SUCCESS'])
-        self.algorithms = set(['0xffffffff'])
-        self.ecc_curves = set(['0xff'])
-        self.dh_groups = set(['0xff'])
-        self.key_types = set(['0xffff'])
-        self.key_usage_flags = set(['0x80000000'])
+        self.statuses.add('PSA_SUCCESS')
+        self.algorithms.add('0xffffffff')
+        self.ecc_curves.add('0xff')
+        self.dh_groups.add('0xff')
+        self.key_types.add('0xffff')
+        self.key_usage_flags.add('0x80000000')
+
         # Hard-coded values for unknown algorithms
         #
         # These have to have values that are correct for their respective
         # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
         # not likely to be assigned in the near future.
-        self.hash_algorithms = set(['0x020000fe']) # 0x020000ff is PSA_ALG_ANY_HASH
-        self.mac_algorithms = set(['0x03007fff'])
-        self.ka_algorithms = set(['0x09fc0000'])
-        self.kdf_algorithms = set(['0x080000ff'])
+        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
+        self.mac_algorithms.add('0x03007fff')
+        self.ka_algorithms.add('0x09fc0000')
+        self.kdf_algorithms.add('0x080000ff')
         # For AEAD algorithms, the only variability is over the tag length,
         # and this only applies to known algorithms, so don't test an
         # unknown algorithm.
-        self.aead_algorithms = set()
+
         # Identifier prefixes
         self.table_by_prefix = {
             'ERROR': self.statuses,
@@ -140,15 +246,10 @@
             'asymmetric_encryption_algorithm': [],
             'other_algorithm': [],
         }
-        # macro name -> list of argument names
-        self.argspecs = {}
-        # argument name -> list of values
-        self.arguments_for = {
-            'mac_length': ['1', '63'],
-            'tag_length': ['1', '63'],
-            'min_mac_length': ['1', '63'],
-            'min_tag_length': ['1', '63'],
-        }
+        self.arguments_for['mac_length'] += ['1', '63']
+        self.arguments_for['min_mac_length'] += ['1', '63']
+        self.arguments_for['tag_length'] += ['1', '63']
+        self.arguments_for['min_tag_length'] += ['1', '63']
 
     def get_names(self, type_word):
         """Return the set of known names of values of the given type."""
@@ -161,62 +262,6 @@
             'key_usage': self.key_usage_flags,
         }[type_word]
 
-    def gather_arguments(self):
-        """Populate the list of values for macro arguments.
-
-        Call this after parsing all the inputs.
-        """
-        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
-        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
-        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
-        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
-        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
-        self.arguments_for['curve'] = sorted(self.ecc_curves)
-        self.arguments_for['group'] = sorted(self.dh_groups)
-
-    @staticmethod
-    def _format_arguments(name, arguments):
-        """Format a macro call with arguments.."""
-        return name + '(' + ', '.join(arguments) + ')'
-
-    def distribute_arguments(self, name):
-        """Generate macro calls with each tested argument set.
-
-        If name is a macro without arguments, just yield "name".
-        If name is a macro with arguments, yield a series of
-        "name(arg1,...,argN)" where each argument takes each possible
-        value at least once.
-        """
-        try:
-            if name not in self.argspecs:
-                yield name
-                return
-            argspec = self.argspecs[name]
-            if argspec == []:
-                yield name + '()'
-                return
-            argument_lists = [self.arguments_for[arg] for arg in argspec]
-            arguments = [values[0] for values in argument_lists]
-            yield self._format_arguments(name, arguments)
-            # Dear Pylint, enumerate won't work here since we're modifying
-            # the array.
-            # pylint: disable=consider-using-enumerate
-            for i in range(len(arguments)):
-                for value in argument_lists[i][1:]:
-                    arguments[i] = value
-                    yield self._format_arguments(name, arguments)
-                arguments[i] = argument_lists[0][0]
-        except BaseException as e:
-            raise Exception('distribute_arguments({})'.format(name)) from e
-
-    def generate_expressions(self, names):
-        return itertools.chain(*map(self.distribute_arguments, names))
-
-    _argument_split_re = re.compile(r' *, *')
-    @classmethod
-    def _argument_split(cls, arguments):
-        return re.split(cls._argument_split_re, arguments)
-
     # Regex for interesting header lines.
     # Groups: 1=macro name, 2=type, 3=argument list (optional).
     _header_line_re = \
@@ -301,7 +346,7 @@
                 if m:
                     self.add_test_case_line(m.group(1), m.group(2))
 
-def gather_inputs(headers, test_suites, inputs_class=Inputs):
+def gather_inputs(headers, test_suites, inputs_class=InputsForTest):
     """Read the list of inputs to test psa_constant_names with."""
     inputs = inputs_class()
     for header in headers: