Merge pull request #6296 from gilles-peskine-arm/test_data_generation-pr_6093_followup

Minor fixes to test_data_generation.py
diff --git a/scripts/abi_check.py b/scripts/abi_check.py
index c228843..ac1d60f 100755
--- a/scripts/abi_check.py
+++ b/scripts/abi_check.py
@@ -113,6 +113,8 @@
 
 import xml.etree.ElementTree as ET
 
+from mbedtls_dev import build_tree
+
 
 class AbiChecker:
     """API and ABI checker."""
@@ -150,11 +152,6 @@
         self.git_command = "git"
         self.make_command = "make"
 
-    @staticmethod
-    def check_repo_path():
-        if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
-            raise Exception("Must be run from Mbed TLS root")
-
     def _setup_logger(self):
         self.log = logging.getLogger()
         if self.verbose:
@@ -540,7 +537,7 @@
     def check_for_abi_changes(self):
         """Generate a report of ABI differences
         between self.old_rev and self.new_rev."""
-        self.check_repo_path()
+        build_tree.check_repo_path()
         if self.check_api or self.check_abi:
             self.check_abi_tools_are_installed()
         self._get_abi_dump_for_ref(self.old_version)
diff --git a/scripts/code_size_compare.py b/scripts/code_size_compare.py
index 0ef438d..af6ddd4 100755
--- a/scripts/code_size_compare.py
+++ b/scripts/code_size_compare.py
@@ -30,6 +30,9 @@
 import subprocess
 import sys
 
+from mbedtls_dev import build_tree
+
+
 class CodeSizeComparison:
     """Compare code size between two Git revisions."""
 
@@ -52,11 +55,6 @@
         self.make_command = "make"
 
     @staticmethod
-    def check_repo_path():
-        if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
-            raise Exception("Must be run from Mbed TLS root")
-
-    @staticmethod
     def validate_revision(revision):
         result = subprocess.check_output(["git", "rev-parse", "--verify",
                                           revision + "^{commit}"], shell=False)
@@ -172,7 +170,7 @@
     def get_comparision_results(self):
         """Compare size of library/*.o between self.old_rev and self.new_rev,
         and generate the result file."""
-        self.check_repo_path()
+        build_tree.check_repo_path()
         self._get_code_size_for_rev(self.old_rev)
         self._get_code_size_for_rev(self.new_rev)
         return self.compare_code_size()
diff --git a/scripts/mbedtls_dev/__init__.py b/scripts/mbedtls_dev/__init__.py
new file mode 100644
index 0000000..15b0d60
--- /dev/null
+++ b/scripts/mbedtls_dev/__init__.py
@@ -0,0 +1,3 @@
+# This file needs to exist to make mbedtls_dev a package.
+# Among other things, this allows modules in this directory to make
+# relative imports.
diff --git a/scripts/mbedtls_dev/build_tree.py b/scripts/mbedtls_dev/build_tree.py
index 3920d0e..f52b785 100644
--- a/scripts/mbedtls_dev/build_tree.py
+++ b/scripts/mbedtls_dev/build_tree.py
@@ -25,6 +25,13 @@
     return all(os.path.isdir(os.path.join(path, subdir))
                for subdir in ['include', 'library', 'programs', 'tests'])
 
+def check_repo_path():
+    """
+    Check that the current working directory is the project root, and raise
+    an exception if not.
+    """
+    if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
+        raise Exception("This script must be run from Mbed TLS root")
 
 def chdir_to_root() -> None:
     """Detect the root of the Mbed TLS source tree and change to it.
diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py
index f52ca9a..1a03321 100644
--- a/scripts/mbedtls_dev/crypto_knowledge.py
+++ b/scripts/mbedtls_dev/crypto_knowledge.py
@@ -22,7 +22,7 @@
 import re
 from typing import FrozenSet, Iterable, List, Optional, Tuple
 
-from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
+from .asymmetric_key_data import ASYMMETRIC_KEY_DATA
 
 
 def short_expression(original: str, level: int = 0) -> str:
diff --git a/scripts/mbedtls_dev/psa_storage.py b/scripts/mbedtls_dev/psa_storage.py
index a06dce1..bae9938 100644
--- a/scripts/mbedtls_dev/psa_storage.py
+++ b/scripts/mbedtls_dev/psa_storage.py
@@ -26,7 +26,7 @@
 from typing import Dict, List, Optional, Set, Union
 import unittest
 
-from mbedtls_dev import c_build_helper
+from . import c_build_helper
 
 
 class Expr:
diff --git a/scripts/mbedtls_dev/test_case.py b/scripts/mbedtls_dev/test_case.py
index d0afa59..8f08703 100644
--- a/scripts/mbedtls_dev/test_case.py
+++ b/scripts/mbedtls_dev/test_case.py
@@ -1,4 +1,4 @@
-"""Library for generating Mbed TLS test data.
+"""Library for constructing an Mbed TLS test case.
 """
 
 # Copyright The Mbed TLS Contributors
@@ -21,7 +21,7 @@
 import sys
 from typing import Iterable, List, Optional
 
-from mbedtls_dev import typing_util
+from . import typing_util
 
 def hex_string(data: bytes) -> str:
     return '"' + binascii.hexlify(data).decode('ascii') + '"'
diff --git a/scripts/mbedtls_dev/test_generation.py b/scripts/mbedtls_dev/test_data_generation.py
similarity index 90%
rename from scripts/mbedtls_dev/test_generation.py
rename to scripts/mbedtls_dev/test_data_generation.py
index a88425f..cdb1c03 100644
--- a/scripts/mbedtls_dev/test_generation.py
+++ b/scripts/mbedtls_dev/test_data_generation.py
@@ -1,4 +1,7 @@
-"""Common test generation classes and main function.
+"""Common code for test data generation.
+
+This module defines general-purpose classes for automatically generating
+.data files for unit tests, as well as a main function.
 
 These are used both by generate_psa_tests.py and generate_bignum_tests.py.
 """
@@ -26,8 +29,8 @@
 from abc import ABCMeta, abstractmethod
 from typing import Callable, Dict, Iterable, Iterator, List, Type, TypeVar
 
-from mbedtls_dev import build_tree
-from mbedtls_dev import test_case
+from . import build_tree
+from . import test_case
 
 T = TypeVar('T') #pylint: disable=invalid-name
 
@@ -138,8 +141,7 @@
 class TestGenerator:
     """Generate test cases and write to data files."""
     def __init__(self, options) -> None:
-        self.test_suite_directory = self.get_option(options, 'directory',
-                                                    'tests/suites')
+        self.test_suite_directory = options.directory
         # Update `targets` with an entry for each child class of BaseTarget.
         # Each entry represents a file generated by the BaseTarget framework,
         # and enables generating the .data files using the CLI.
@@ -148,11 +150,6 @@
             for subclass in BaseTarget.__subclasses__()
         })
 
-    @staticmethod
-    def get_option(options, name: str, default: T) -> T:
-        value = getattr(options, name, None)
-        return default if value is None else value
-
     def filename_for(self, basename: str) -> str:
         """The location of the data file with the specified base name."""
         return posixpath.join(self.test_suite_directory, basename + '.data')
@@ -186,16 +183,24 @@
                         help='List available targets and exit')
     parser.add_argument('--list-for-cmake', action='store_true',
                         help='Print \';\'-separated list of available targets and exit')
+    # If specified explicitly, this option may be a path relative to the
+    # current directory when the script is invoked. The default value
+    # is relative to the mbedtls root, which we don't know yet. So we
+    # can't set a string as the default value here.
     parser.add_argument('--directory', metavar='DIR',
                         help='Output directory (default: tests/suites)')
-    # The `--directory` option is interpreted relative to the directory from
-    # which the script is invoked, but the default is relative to the root of
-    # the mbedtls tree. The default should not be set above, but instead after
-    # `build_tree.chdir_to_root()` is called.
     parser.add_argument('targets', nargs='*', metavar='TARGET',
                         help='Target file to generate (default: all; "-": none)')
     options = parser.parse_args(args)
+
+    # Change to the mbedtls root, to keep things simple. But first, adjust
+    # command line options that might be relative paths.
+    if options.directory is None:
+        options.directory = 'tests/suites'
+    else:
+        options.directory = os.path.abspath(options.directory)
     build_tree.chdir_to_root()
+
     generator = generator_class(options)
     if options.list:
         for name in sorted(generator.targets):
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 57cf977..d89542a 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -16,38 +16,44 @@
 # generated .data files will go there
 file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/suites)
 
-# Get base names for generated files (starting at "suites/")
+# Get base names for generated files
 execute_process(
     COMMAND
         ${MBEDTLS_PYTHON_EXECUTABLE}
         ${CMAKE_CURRENT_SOURCE_DIR}/../tests/scripts/generate_bignum_tests.py
         --list-for-cmake
-        --directory suites
     WORKING_DIRECTORY
         ${CMAKE_CURRENT_SOURCE_DIR}/..
     OUTPUT_VARIABLE
         base_bignum_generated_data_files)
+string(REGEX REPLACE "[^;]*/" ""
+       base_bignum_generated_data_files "${base_bignum_generated_data_files}")
 
 execute_process(
     COMMAND
         ${MBEDTLS_PYTHON_EXECUTABLE}
         ${CMAKE_CURRENT_SOURCE_DIR}/../tests/scripts/generate_psa_tests.py
         --list-for-cmake
-        --directory suites
     WORKING_DIRECTORY
         ${CMAKE_CURRENT_SOURCE_DIR}/..
     OUTPUT_VARIABLE
         base_psa_generated_data_files)
+string(REGEX REPLACE "[^;]*/" ""
+       base_psa_generated_data_files "${base_psa_generated_data_files}")
 
-# Derive generated file paths in the build directory
-set(base_generated_data_files ${base_bignum_generated_data_files} ${base_psa_generated_data_files})
+# Derive generated file paths in the build directory. The generated data
+# files go into the suites/ subdirectory.
+set(base_generated_data_files
+    ${base_bignum_generated_data_files} ${base_psa_generated_data_files})
+string(REGEX REPLACE "([^;]+)" "suites/\\1"
+       all_generated_data_files "${base_generated_data_files}")
 set(bignum_generated_data_files "")
 set(psa_generated_data_files "")
 foreach(file ${base_bignum_generated_data_files})
-    list(APPEND bignum_generated_data_files ${CMAKE_CURRENT_BINARY_DIR}/${file})
+    list(APPEND bignum_generated_data_files ${CMAKE_CURRENT_BINARY_DIR}/suites/${file})
 endforeach()
 foreach(file ${base_psa_generated_data_files})
-    list(APPEND psa_generated_data_files ${CMAKE_CURRENT_BINARY_DIR}/${file})
+    list(APPEND psa_generated_data_files ${CMAKE_CURRENT_BINARY_DIR}/suites/${file})
 endforeach()
 
 if(GEN_FILES)
@@ -63,7 +69,7 @@
         DEPENDS
             ${CMAKE_CURRENT_SOURCE_DIR}/../tests/scripts/generate_bignum_tests.py
             ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/test_case.py
-            ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/test_generation.py
+            ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/test_data_generation.py
     )
     add_custom_command(
         OUTPUT
@@ -80,14 +86,14 @@
             ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/macro_collector.py
             ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/psa_storage.py
             ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/test_case.py
-            ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/test_generation.py
+            ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/mbedtls_dev/test_data_generation.py
             ${CMAKE_CURRENT_SOURCE_DIR}/../include/psa/crypto_config.h
             ${CMAKE_CURRENT_SOURCE_DIR}/../include/psa/crypto_values.h
             ${CMAKE_CURRENT_SOURCE_DIR}/../include/psa/crypto_extra.h
     )
 
 else()
-    foreach(file ${base_generated_data_files})
+    foreach(file ${all_generated_data_files})
         link_to_source(${file})
     endforeach()
 endif()
@@ -210,9 +216,9 @@
 endif(MSVC)
 
 file(GLOB test_suites RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" suites/*.data)
-list(APPEND test_suites ${base_generated_data_files})
+list(APPEND test_suites ${all_generated_data_files})
 # If the generated .data files are present in the source tree, we just added
-# them twice, both through GLOB and through ${base_generated_data_files}.
+# them twice, both through GLOB and through ${all_generated_data_files}.
 list(REMOVE_DUPLICATES test_suites)
 list(SORT test_suites)
 foreach(test_suite ${test_suites})
diff --git a/tests/Makefile b/tests/Makefile
index 8777ae9..57f8855 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -93,7 +93,7 @@
 $(GENERATED_BIGNUM_DATA_FILES): generated_bignum_test_data
 generated_bignum_test_data: scripts/generate_bignum_tests.py
 generated_bignum_test_data: ../scripts/mbedtls_dev/test_case.py
-generated_bignum_test_data: ../scripts/mbedtls_dev/test_generation.py
+generated_bignum_test_data: ../scripts/mbedtls_dev/test_data_generation.py
 generated_bignum_test_data:
 	echo "  Gen   $(GENERATED_BIGNUM_DATA_FILES)"
 	$(PYTHON) scripts/generate_bignum_tests.py
@@ -104,7 +104,7 @@
 generated_psa_test_data: ../scripts/mbedtls_dev/macro_collector.py
 generated_psa_test_data: ../scripts/mbedtls_dev/psa_storage.py
 generated_psa_test_data: ../scripts/mbedtls_dev/test_case.py
-generated_psa_test_data: ../scripts/mbedtls_dev/test_generation.py
+generated_psa_test_data: ../scripts/mbedtls_dev/test_data_generation.py
 ## The generated file only depends on the options that are present in
 ## crypto_config.h, not on which options are set. To avoid regenerating this
 ## file all the time when switching between configurations, don't declare
diff --git a/tests/scripts/check-python-files.sh b/tests/scripts/check-python-files.sh
index dbf0365..35319d3 100755
--- a/tests/scripts/check-python-files.sh
+++ b/tests/scripts/check-python-files.sh
@@ -67,7 +67,7 @@
 fi
 
 echo 'Running pylint ...'
-$PYTHON -m pylint -j 2 scripts/mbedtls_dev/*.py scripts/*.py tests/scripts/*.py || {
+$PYTHON -m pylint scripts/mbedtls_dev/*.py scripts/*.py tests/scripts/*.py || {
     echo >&2 "pylint reported errors"
     ret=1
 }
diff --git a/tests/scripts/check_files.py b/tests/scripts/check_files.py
index a0f5e1f..5c18702 100755
--- a/tests/scripts/check_files.py
+++ b/tests/scripts/check_files.py
@@ -34,6 +34,9 @@
 except ImportError:
     pass
 
+import scripts_path # pylint: disable=unused-import
+from mbedtls_dev import build_tree
+
 
 class FileIssueTracker:
     """Base class for file-wide issue tracking.
@@ -338,7 +341,7 @@
         """Instantiate the sanity checker.
         Check files under the current directory.
         Write a report of issues to log_file."""
-        self.check_repo_path()
+        build_tree.check_repo_path()
         self.logger = None
         self.setup_logger(log_file)
         self.issues_to_check = [
@@ -353,11 +356,6 @@
             MergeArtifactIssueTracker(),
         ]
 
-    @staticmethod
-    def check_repo_path():
-        if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
-            raise Exception("Must be run from Mbed TLS root")
-
     def setup_logger(self, log_file, level=logging.INFO):
         self.logger = logging.getLogger()
         self.logger.setLevel(level)
diff --git a/tests/scripts/check_names.py b/tests/scripts/check_names.py
index e204487..aece1ef 100755
--- a/tests/scripts/check_names.py
+++ b/tests/scripts/check_names.py
@@ -56,6 +56,10 @@
 import subprocess
 import logging
 
+import scripts_path # pylint: disable=unused-import
+from mbedtls_dev import build_tree
+
+
 # Naming patterns to check against. These are defined outside the NameCheck
 # class for ease of modification.
 PUBLIC_MACRO_PATTERN = r"^(MBEDTLS|PSA)_[0-9A-Z_]*[0-9A-Z]$"
@@ -219,7 +223,7 @@
     """
     def __init__(self, log):
         self.log = log
-        self.check_repo_path()
+        build_tree.check_repo_path()
 
         # Memo for storing "glob expression": set(filepaths)
         self.files = {}
@@ -228,15 +232,6 @@
         # Note that "*" can match directory separators in exclude lists.
         self.excluded_files = ["*/bn_mul", "*/compat-2.x.h"]
 
-    @staticmethod
-    def check_repo_path():
-        """
-        Check that the current working directory is the project root, and throw
-        an exception if not.
-        """
-        if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
-            raise Exception("This script must be run from Mbed TLS root")
-
     def comprehensive_parse(self):
         """
         Comprehensive ("default") function to call each parsing function and
diff --git a/tests/scripts/generate_bignum_tests.py b/tests/scripts/generate_bignum_tests.py
index 8a4d281..7f332dc 100755
--- a/tests/scripts/generate_bignum_tests.py
+++ b/tests/scripts/generate_bignum_tests.py
@@ -6,7 +6,7 @@
 
 Class structure:
 
-Child classes of test_generation.BaseTarget (file targets) represent an output
+Child classes of test_data_generation.BaseTarget (file targets) represent an output
 file. These indicate where test cases will be written to, for all subclasses of
 this target. Multiple file targets should not reuse a `target_basename`.
 
@@ -36,7 +36,7 @@
         call `.create_test_case()` to yield the TestCase.
 
 Additional details and other attributes/methods are given in the documentation
-of BaseTarget in test_generation.py.
+of BaseTarget in test_data_generation.py.
 """
 
 # Copyright The Mbed TLS Contributors
@@ -63,7 +63,7 @@
 
 import scripts_path # pylint: disable=unused-import
 from mbedtls_dev import test_case
-from mbedtls_dev import test_generation
+from mbedtls_dev import test_data_generation
 
 T = TypeVar('T') #pylint: disable=invalid-name
 
@@ -74,18 +74,16 @@
     return "\"{}\"".format(val)
 
 def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
-    """Return all pair combinations from input values.
-
-    The return value is cast, as older versions of mypy are unable to derive
-    the specific type returned by itertools.combinations_with_replacement.
-    """
+    """Return all pair combinations from input values."""
+    # The return value is cast, as older versions of mypy are unable to derive
+    # the specific type returned by itertools.combinations_with_replacement.
     return typing.cast(
         List[Tuple[T, T]],
         list(itertools.combinations_with_replacement(values, 2))
     )
 
 
-class BignumTarget(test_generation.BaseTarget, metaclass=ABCMeta):
+class BignumTarget(test_data_generation.BaseTarget, metaclass=ABCMeta):
     #pylint: disable=abstract-method
     """Target for bignum (mpi) test case generation."""
     target_basename = 'test_suite_mpi.generated'
@@ -235,4 +233,4 @@
 
 if __name__ == '__main__':
     # Use the section of the docstring relevant to the CLI as description
-    test_generation.main(sys.argv[1:], "\n".join(__doc__.splitlines()[:4]))
+    test_data_generation.main(sys.argv[1:], "\n".join(__doc__.splitlines()[:4]))
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index c788fd7..2f09007 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -30,7 +30,7 @@
 from mbedtls_dev import macro_collector
 from mbedtls_dev import psa_storage
 from mbedtls_dev import test_case
-from mbedtls_dev import test_generation
+from mbedtls_dev import test_data_generation
 
 
 def psa_want_symbol(name: str) -> str:
@@ -892,7 +892,7 @@
         yield from super().generate_all_keys()
         yield from self.all_keys_for_implicit_usage()
 
-class PSATestGenerator(test_generation.TestGenerator):
+class PSATestGenerator(test_data_generation.TestGenerator):
     """Test generator subclass including PSA targets and info."""
     # Note that targets whose names contain 'test_format' have their content
     # validated by `abi_check.py`.
@@ -917,4 +917,4 @@
         super().generate_target(name, self.info)
 
 if __name__ == '__main__':
-    test_generation.main(sys.argv[1:], __doc__, PSATestGenerator)
+    test_data_generation.main(sys.argv[1:], __doc__, PSATestGenerator)