Move PSAMacroEnumerator to macro_collector

It's useful for more than test_psa_constant_names.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py
index 7ebd8f7..c9e6ec3 100644
--- a/scripts/mbedtls_dev/macro_collector.py
+++ b/scripts/mbedtls_dev/macro_collector.py
@@ -16,8 +16,115 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import re
-from typing import Dict, Set
+from typing import Dict, Iterable, Iterator, List, Set
+
+
+class PSAMacroEnumerator:
+    """Information about constructors of various PSA Crypto types.
+
+    This includes macro names as well as information about their arguments
+    when applicable.
+
+    This class only provides ways to enumerate expressions that evaluate to
+    values of the covered types. Derived classes are expected to populate
+    the set of known constructors of each kind, as well as populate
+    `self.arguments_for` for arguments that are not of a kind that is
+    enumerated here.
+    """
+
+    def __init__(self) -> None:
+        """Set up an empty set of known constructor macros.
+        """
+        self.statuses = set() #type: Set[str]
+        self.algorithms = set() #type: Set[str]
+        self.ecc_curves = set() #type: Set[str]
+        self.dh_groups = set() #type: Set[str]
+        self.key_types = set() #type: Set[str]
+        self.key_usage_flags = set() #type: Set[str]
+        self.hash_algorithms = set() #type: Set[str]
+        self.mac_algorithms = set() #type: Set[str]
+        self.ka_algorithms = set() #type: Set[str]
+        self.kdf_algorithms = set() #type: Set[str]
+        self.aead_algorithms = set() #type: Set[str]
+        # macro name -> list of argument names
+        self.argspecs = {} #type: Dict[str, List[str]]
+        # argument name -> list of values
+        self.arguments_for = {
+            'mac_length': [],
+            'min_mac_length': [],
+            'tag_length': [],
+            'min_tag_length': [],
+        } #type: Dict[str, List[str]]
+
+    def gather_arguments(self) -> None:
+        """Populate the list of values for macro arguments.
+
+        Call this after parsing all the inputs.
+        """
+        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
+        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
+        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
+        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
+        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
+        self.arguments_for['curve'] = sorted(self.ecc_curves)
+        self.arguments_for['group'] = sorted(self.dh_groups)
+
+    @staticmethod
+    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
+        """Format a macro call with arguments.."""
+        return name + '(' + ', '.join(arguments) + ')'
+
+    _argument_split_re = re.compile(r' *, *')
+    @classmethod
+    def _argument_split(cls, arguments: str) -> List[str]:
+        return re.split(cls._argument_split_re, arguments)
+
+    def distribute_arguments(self, name: str) -> Iterator[str]:
+        """Generate macro calls with each tested argument set.
+
+        If name is a macro without arguments, just yield "name".
+        If name is a macro with arguments, yield a series of
+        "name(arg1,...,argN)" where each argument takes each possible
+        value at least once.
+        """
+        try:
+            if name not in self.argspecs:
+                yield name
+                return
+            argspec = self.argspecs[name]
+            if argspec == []:
+                yield name + '()'
+                return
+            argument_lists = [self.arguments_for[arg] for arg in argspec]
+            arguments = [values[0] for values in argument_lists]
+            yield self._format_arguments(name, arguments)
+            # Dear Pylint, enumerate won't work here since we're modifying
+            # the array.
+            # pylint: disable=consider-using-enumerate
+            for i in range(len(arguments)):
+                for value in argument_lists[i][1:]:
+                    arguments[i] = value
+                    yield self._format_arguments(name, arguments)
+                arguments[i] = argument_lists[0][0]
+        except BaseException as e:
+            raise Exception('distribute_arguments({})'.format(name)) from e
+
+    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
+        """Generate expressions covering values constructed from the given names.
+
+        `names` can be any iterable collection of macro names.
+
+        For example:
+        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
+          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
+          every known hash algorithm ``h``.
+        * ``macros.generate_expressions(macros.key_types)`` generates all
+          key types.
+        """
+        return itertools.chain(*map(self.distribute_arguments, names))
+
 
 class PSAMacroCollector:
     """Collect PSA crypto macro definitions from C header files.
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 21a2e8d..b3fdb8d 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -24,119 +24,14 @@
 
 import argparse
 from collections import namedtuple
-import itertools
 import os
 import re
 import subprocess
 import sys
-from typing import Dict, Iterable, Iterator, List, Set
 
 import scripts_path # pylint: disable=unused-import
 from mbedtls_dev import c_build_helper
-
-class PSAMacroEnumerator:
-    """Information about constructors of various PSA Crypto types.
-
-    This includes macro names as well as information about their arguments
-    when applicable.
-
-    This class only provides ways to enumerate expressions that evaluate to
-    values of the covered types. Derived classes are expected to populate
-    the set of known constructors of each kind, as well as populate
-    `self.arguments_for` for arguments that are not of a kind that is
-    enumerated here.
-    """
-
-    def __init__(self) -> None:
-        """Set up an empty set of known constructor macros.
-        """
-        self.statuses = set() #type: Set[str]
-        self.algorithms = set() #type: Set[str]
-        self.ecc_curves = set() #type: Set[str]
-        self.dh_groups = set() #type: Set[str]
-        self.key_types = set() #type: Set[str]
-        self.key_usage_flags = set() #type: Set[str]
-        self.hash_algorithms = set() #type: Set[str]
-        self.mac_algorithms = set() #type: Set[str]
-        self.ka_algorithms = set() #type: Set[str]
-        self.kdf_algorithms = set() #type: Set[str]
-        self.aead_algorithms = set() #type: Set[str]
-        # macro name -> list of argument names
-        self.argspecs = {} #type: Dict[str, List[str]]
-        # argument name -> list of values
-        self.arguments_for = {
-            'mac_length': [],
-            'min_mac_length': [],
-            'tag_length': [],
-            'min_tag_length': [],
-        } #type: Dict[str, List[str]]
-
-    def gather_arguments(self) -> None:
-        """Populate the list of values for macro arguments.
-
-        Call this after parsing all the inputs.
-        """
-        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
-        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
-        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
-        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
-        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
-        self.arguments_for['curve'] = sorted(self.ecc_curves)
-        self.arguments_for['group'] = sorted(self.dh_groups)
-
-    @staticmethod
-    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
-        """Format a macro call with arguments.."""
-        return name + '(' + ', '.join(arguments) + ')'
-
-    _argument_split_re = re.compile(r' *, *')
-    @classmethod
-    def _argument_split(cls, arguments: str) -> List[str]:
-        return re.split(cls._argument_split_re, arguments)
-
-    def distribute_arguments(self, name: str) -> Iterator[str]:
-        """Generate macro calls with each tested argument set.
-
-        If name is a macro without arguments, just yield "name".
-        If name is a macro with arguments, yield a series of
-        "name(arg1,...,argN)" where each argument takes each possible
-        value at least once.
-        """
-        try:
-            if name not in self.argspecs:
-                yield name
-                return
-            argspec = self.argspecs[name]
-            if argspec == []:
-                yield name + '()'
-                return
-            argument_lists = [self.arguments_for[arg] for arg in argspec]
-            arguments = [values[0] for values in argument_lists]
-            yield self._format_arguments(name, arguments)
-            # Dear Pylint, enumerate won't work here since we're modifying
-            # the array.
-            # pylint: disable=consider-using-enumerate
-            for i in range(len(arguments)):
-                for value in argument_lists[i][1:]:
-                    arguments[i] = value
-                    yield self._format_arguments(name, arguments)
-                arguments[i] = argument_lists[0][0]
-        except BaseException as e:
-            raise Exception('distribute_arguments({})'.format(name)) from e
-
-    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
-        """Generate expressions covering values constructed from the given names.
-
-        `names` can be any iterable collection of macro names.
-
-        For example:
-        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
-          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
-          every known hash algorithm ``h``.
-        * ``macros.generate_expressions(macros.key_types)`` generates all
-          key types.
-        """
-        return itertools.chain(*map(self.distribute_arguments, names))
+from mbedtls_dev import macro_collector
 
 class ReadFileLineException(Exception):
     def __init__(self, filename, line_number):
@@ -183,7 +78,7 @@
             raise ReadFileLineException(self.filename, self.line_number) \
                 from exc_value
 
-class InputsForTest(PSAMacroEnumerator):
+class InputsForTest(macro_collector.PSAMacroEnumerator):
     # pylint: disable=too-many-instance-attributes
     """Accumulate information about macros to test.