Merge pull request #6424 from gilles-peskine-arm/test_data_generation-pr_6093_followup-2.28
Backport 2.28: Minor fixes to test_data_generation.py
diff --git a/scripts/abi_check.py b/scripts/abi_check.py
index c228843..ac1d60f 100755
--- a/scripts/abi_check.py
+++ b/scripts/abi_check.py
@@ -113,6 +113,8 @@
import xml.etree.ElementTree as ET
+from mbedtls_dev import build_tree
+
class AbiChecker:
"""API and ABI checker."""
@@ -150,11 +152,6 @@
self.git_command = "git"
self.make_command = "make"
- @staticmethod
- def check_repo_path():
- if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
- raise Exception("Must be run from Mbed TLS root")
-
def _setup_logger(self):
self.log = logging.getLogger()
if self.verbose:
@@ -540,7 +537,7 @@
def check_for_abi_changes(self):
"""Generate a report of ABI differences
between self.old_rev and self.new_rev."""
- self.check_repo_path()
+ build_tree.check_repo_path()
if self.check_api or self.check_abi:
self.check_abi_tools_are_installed()
self._get_abi_dump_for_ref(self.old_version)
diff --git a/scripts/mbedtls_dev/__init__.py b/scripts/mbedtls_dev/__init__.py
new file mode 100644
index 0000000..15b0d60
--- /dev/null
+++ b/scripts/mbedtls_dev/__init__.py
@@ -0,0 +1,3 @@
+# This file needs to exist to make mbedtls_dev a package.
+# Among other things, this allows modules in this directory to make
+# relative imports.
diff --git a/scripts/mbedtls_dev/build_tree.py b/scripts/mbedtls_dev/build_tree.py
new file mode 100644
index 0000000..f52b785
--- /dev/null
+++ b/scripts/mbedtls_dev/build_tree.py
@@ -0,0 +1,67 @@
+"""Mbed TLS build tree information and manipulation.
+"""
+
+# Copyright The Mbed TLS Contributors
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import inspect
+
+
+def looks_like_mbedtls_root(path: str) -> bool:
+    """Return whether the directory at path looks like an Mbed TLS source tree root."""
+    required = ('include', 'library', 'programs', 'tests')
+    return all(os.path.isdir(os.path.join(path, name)) for name in required)
+
+def check_repo_path():
+    """Raise an exception unless the current directory is the Mbed TLS root.
+
+    Only the presence of the include, library and tests subdirectories is checked.
+    """
+    if any(not os.path.isdir(d) for d in ("include", "library", "tests")):
+        raise Exception("This script must be run from Mbed TLS root")
+
+def chdir_to_root() -> None:
+    """Change the current directory to the root of the Mbed TLS source tree.
+
+    The root is looked for in the current directory, then one and two levels up.
+    Raise an exception if none of these looks like an Mbed TLS source tree.
+    """
+    candidates = (os.path.curdir,
+                  os.path.pardir,
+                  os.path.join(os.path.pardir, os.path.pardir))
+    root = next((d for d in candidates if looks_like_mbedtls_root(d)), None)
+    if root is None:
+        raise Exception('Mbed TLS source tree not found')
+    os.chdir(root)
+
+
+def guess_mbedtls_root():
+    """Guess the Mbed TLS source code directory.
+
+    For each file on the call stack, try its directory and up to 9 levels above it.
+    Return the first one that looks like an Mbed TLS root, or raise an exception.
+    """
+    seen = set()
+    for frame in inspect.stack(0):  # context=0: source lines are not needed
+        base = os.path.dirname(frame.filename)
+        for levels in range(10):
+            d = os.path.abspath(os.path.join(base, *([os.path.pardir] * levels)))
+            if d in seen:
+                continue
+            seen.add(d)
+            if looks_like_mbedtls_root(d):
+                return d
+    raise Exception('Mbed TLS source tree not found')
diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py
index 2173d10..f227a41 100644
--- a/scripts/mbedtls_dev/crypto_knowledge.py
+++ b/scripts/mbedtls_dev/crypto_knowledge.py
@@ -22,7 +22,7 @@
import re
from typing import FrozenSet, Iterable, List, Optional, Tuple
-from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
+from .asymmetric_key_data import ASYMMETRIC_KEY_DATA
def short_expression(original: str, level: int = 0) -> str:
diff --git a/scripts/mbedtls_dev/psa_storage.py b/scripts/mbedtls_dev/psa_storage.py
index a06dce1..bae9938 100644
--- a/scripts/mbedtls_dev/psa_storage.py
+++ b/scripts/mbedtls_dev/psa_storage.py
@@ -26,7 +26,7 @@
from typing import Dict, List, Optional, Set, Union
import unittest
-from mbedtls_dev import c_build_helper
+from . import c_build_helper
class Expr:
diff --git a/scripts/mbedtls_dev/test_case.py b/scripts/mbedtls_dev/test_case.py
index d0afa59..8f08703 100644
--- a/scripts/mbedtls_dev/test_case.py
+++ b/scripts/mbedtls_dev/test_case.py
@@ -1,4 +1,4 @@
-"""Library for generating Mbed TLS test data.
+"""Library for constructing an Mbed TLS test case.
"""
# Copyright The Mbed TLS Contributors
@@ -21,7 +21,7 @@
import sys
from typing import Iterable, List, Optional
-from mbedtls_dev import typing_util
+from . import typing_util
def hex_string(data: bytes) -> str:
return '"' + binascii.hexlify(data).decode('ascii') + '"'
diff --git a/scripts/mbedtls_dev/test_generation.py b/scripts/mbedtls_dev/test_data_generation.py
similarity index 93%
rename from scripts/mbedtls_dev/test_generation.py
rename to scripts/mbedtls_dev/test_data_generation.py
index 5de0b88..9e36af3 100644
--- a/scripts/mbedtls_dev/test_generation.py
+++ b/scripts/mbedtls_dev/test_data_generation.py
@@ -1,4 +1,7 @@
-"""Common test generation classes and main function.
+"""Common code for test data generation.
+
+This module defines classes that are of general use to automatically
+generate .data files for unit tests, as well as a main function.
These are used both by generate_psa_tests.py and generate_bignum_tests.py.
"""
@@ -26,7 +29,8 @@
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, Iterable, Iterator, List, Type, TypeVar
-from mbedtls_dev import test_case
+from . import build_tree
+from . import test_case
T = TypeVar('T') #pylint: disable=invalid-name
@@ -136,9 +140,8 @@
class TestGenerator:
"""Generate test cases and write to data files."""
- def __init__(self, options) -> None:
- self.test_suite_directory = self.get_option(options, 'directory',
- 'tests/suites')
+ def __init__(self, _options) -> None:
+ self.test_suite_directory = 'tests/suites'
# Update `targets` with an entry for each child class of BaseTarget.
# Each entry represents a file generated by the BaseTarget framework,
# and enables generating the .data files using the CLI.
@@ -147,11 +150,6 @@
for subclass in BaseTarget.__subclasses__()
})
- @staticmethod
- def get_option(options, name: str, default: T) -> T:
- value = getattr(options, name, None)
- return default if value is None else value
-
def filename_for(self, basename: str) -> str:
"""The location of the data file with the specified base name."""
return posixpath.join(self.test_suite_directory, basename + '.data')
@@ -185,6 +183,12 @@
help='List available targets and exit')
parser.add_argument('targets', nargs='*', metavar='TARGET',
help='Target file to generate (default: all; "-": none)')
+
+ # Change to the mbedtls root, to keep things simple.
+ # Note that if any command line options refer to paths, they need to
+ # be adjusted first.
+ build_tree.chdir_to_root()
+
options = parser.parse_args(args)
generator = generator_class(options)
if options.list:
diff --git a/tests/scripts/check-python-files.sh b/tests/scripts/check-python-files.sh
index dbf0365..35319d3 100755
--- a/tests/scripts/check-python-files.sh
+++ b/tests/scripts/check-python-files.sh
@@ -67,7 +67,7 @@
fi
echo 'Running pylint ...'
-$PYTHON -m pylint -j 2 scripts/mbedtls_dev/*.py scripts/*.py tests/scripts/*.py || {
+$PYTHON -m pylint scripts/mbedtls_dev/*.py scripts/*.py tests/scripts/*.py || {
echo >&2 "pylint reported errors"
ret=1
}
diff --git a/tests/scripts/check_files.py b/tests/scripts/check_files.py
index a0f5e1f..5c18702 100755
--- a/tests/scripts/check_files.py
+++ b/tests/scripts/check_files.py
@@ -34,6 +34,9 @@
except ImportError:
pass
+import scripts_path # pylint: disable=unused-import
+from mbedtls_dev import build_tree
+
class FileIssueTracker:
"""Base class for file-wide issue tracking.
@@ -338,7 +341,7 @@
"""Instantiate the sanity checker.
Check files under the current directory.
Write a report of issues to log_file."""
- self.check_repo_path()
+ build_tree.check_repo_path()
self.logger = None
self.setup_logger(log_file)
self.issues_to_check = [
@@ -353,11 +356,6 @@
MergeArtifactIssueTracker(),
]
- @staticmethod
- def check_repo_path():
- if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
- raise Exception("Must be run from Mbed TLS root")
-
def setup_logger(self, log_file, level=logging.INFO):
self.logger = logging.getLogger()
self.logger.setLevel(level)
diff --git a/tests/scripts/check_names.py b/tests/scripts/check_names.py
index 875d0b0f..d1e87b5 100755
--- a/tests/scripts/check_names.py
+++ b/tests/scripts/check_names.py
@@ -56,6 +56,10 @@
import subprocess
import logging
+import scripts_path # pylint: disable=unused-import
+from mbedtls_dev import build_tree
+
+
# Naming patterns to check against. These are defined outside the NameCheck
# class for ease of modification.
MACRO_PATTERN = r"^(MBEDTLS|PSA)_[0-9A-Z_]*[0-9A-Z]$"
@@ -218,7 +222,7 @@
"""
def __init__(self, log):
self.log = log
- self.check_repo_path()
+ build_tree.check_repo_path()
# Memo for storing "glob expression": set(filepaths)
self.files = {}
@@ -227,15 +231,6 @@
# Note that "*" can match directory separators in exclude lists.
self.excluded_files = ["*/bn_mul", "*/compat-1.3.h"]
- @staticmethod
- def check_repo_path():
- """
- Check that the current working directory is the project root, and throw
- an exception if not.
- """
- if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
- raise Exception("This script must be run from Mbed TLS root")
-
def comprehensive_parse(self):
"""
Comprehensive ("default") function to call each parsing function and
diff --git a/tests/scripts/generate_bignum_tests.py b/tests/scripts/generate_bignum_tests.py
index ceafa4a..091630d 100755
--- a/tests/scripts/generate_bignum_tests.py
+++ b/tests/scripts/generate_bignum_tests.py
@@ -6,7 +6,7 @@
Class structure:
-Child classes of test_generation.BaseTarget (file targets) represent an output
+Child classes of test_data_generation.BaseTarget (file targets) represent an output
file. These indicate where test cases will be written to, for all subclasses of
this target. Multiple file targets should not reuse a `target_basename`.
@@ -36,7 +36,7 @@
call `.create_test_case()` to yield the TestCase.
Additional details and other attributes/methods are given in the documentation
-of BaseTarget in test_generation.py.
+of BaseTarget in test_data_generation.py.
"""
# Copyright The Mbed TLS Contributors
@@ -63,7 +63,7 @@
import scripts_path # pylint: disable=unused-import
from mbedtls_dev import test_case
-from mbedtls_dev import test_generation
+from mbedtls_dev import test_data_generation
T = TypeVar('T') #pylint: disable=invalid-name
@@ -74,18 +74,16 @@
return "\"{}\"".format(val)
def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
- """Return all pair combinations from input values.
-
- The return value is cast, as older versions of mypy are unable to derive
- the specific type returned by itertools.combinations_with_replacement.
- """
+ """Return all pair combinations from input values."""
+ # The return value is cast, as older versions of mypy are unable to derive
+ # the specific type returned by itertools.combinations_with_replacement.
return typing.cast(
List[Tuple[T, T]],
list(itertools.combinations_with_replacement(values, 2))
)
-class BignumTarget(test_generation.BaseTarget, metaclass=ABCMeta):
+class BignumTarget(test_data_generation.BaseTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method
"""Target for bignum (mpi) test case generation."""
target_basename = 'test_suite_mpi.generated'
@@ -235,4 +233,4 @@
if __name__ == '__main__':
# Use the section of the docstring relevant to the CLI as description
- test_generation.main(sys.argv[1:], "\n".join(__doc__.splitlines()[:4]))
+ test_data_generation.main(sys.argv[1:], "\n".join(__doc__.splitlines()[:4]))
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index 867a6b4..e7d4048 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -30,7 +30,7 @@
from mbedtls_dev import macro_collector
from mbedtls_dev import psa_storage
from mbedtls_dev import test_case
-from mbedtls_dev import test_generation
+from mbedtls_dev import test_data_generation
def psa_want_symbol(name: str) -> str:
@@ -894,7 +894,7 @@
yield from super().generate_all_keys()
yield from self.all_keys_for_implicit_usage()
-class PSATestGenerator(test_generation.TestGenerator):
+class PSATestGenerator(test_data_generation.TestGenerator):
"""Test generator subclass including PSA targets and info."""
# Note that targets whose names contain 'test_format' have their content
# validated by `abi_check.py`.
@@ -919,4 +919,4 @@
super().generate_target(name, self.info)
if __name__ == '__main__':
- test_generation.main(sys.argv[1:], __doc__, PSATestGenerator)
+ test_data_generation.main(sys.argv[1:], __doc__, PSATestGenerator)