Merge pull request #7399 from lpy4105/issue/7014/certificate-audit-script

cert_audit: Add test certificate date audit script
diff --git a/scripts/ci.requirements.txt b/scripts/ci.requirements.txt
index 1ad983f..3ddc417 100644
--- a/scripts/ci.requirements.txt
+++ b/scripts/ci.requirements.txt
@@ -10,3 +10,9 @@
 # Use the earliest version of mypy that works with our code base.
 # See https://github.com/Mbed-TLS/mbedtls/pull/3953 .
 mypy >= 0.780
+
+# Install cryptography to avoid import-error reported by pylint.
+# What we really need is cryptography >= 35.0.0, which is only
+# available for Python >= 3.6.
+cryptography >= 35.0.0; sys_platform == 'linux' and python_version >= '3.6'
+cryptography;           sys_platform == 'linux' and python_version <  '3.6'
diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py
new file mode 100755
index 0000000..1ccfc21
--- /dev/null
+++ b/tests/scripts/audit-validity-dates.py
@@ -0,0 +1,471 @@
+#!/usr/bin/env python3
+#
+# Copyright The Mbed TLS Contributors
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Audit validity date of X509 crt/crl/csr.
+
+This script is used to audit the validity date of crt/crl/csr used for testing.
+It prints the information about X.509 objects excluding the objects that
+are valid throughout the desired validity period. The data are collected
+from tests/data_files/ and tests/suites/*.data files by default.
+"""
+
+import os
+import sys
+import re
+import typing
+import argparse
+import datetime
+import glob
+import logging
+from enum import Enum
+
+# The script requires cryptography >= 35.0.0 which is only available
+# for Python >= 3.6.
+import cryptography
+from cryptography import x509
+
+from generate_test_code import FileWrapper
+
+import scripts_path # pylint: disable=unused-import
+from mbedtls_dev import build_tree
+
+def check_cryptography_version():
+    match = re.match(r'^[0-9]+', cryptography.__version__)
+    if match is None or int(match[0]) < 35:
+        raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
+                        + " ({} is too old)".format(cryptography.__version__))
+
+class DataType(Enum):
+    CRT = 1 # X.509 Certificate ("CERTIFICATE" PEM tag)
+    CRL = 2 # Certificate Revocation List ("X509 CRL" PEM tag)
+    CSR = 3 # Certificate Signing Request ("CERTIFICATE REQUEST" PEM tag)
+
+
+class DataFormat(Enum):
+    PEM = 1 # Privacy-Enhanced Mail (text with -----BEGIN/END----- lines)
+    DER = 2 # Distinguished Encoding Rules (binary ASN.1 encoding)
+
+
+class AuditData:
+    """Store data location, type and validity period of X.509 objects."""
+    #pylint: disable=too-few-public-methods
+    def __init__(self, data_type: DataType, x509_obj):
+        self.data_type = data_type
+        self.location = ""
+        self.fill_validity_duration(x509_obj)
+
+    def fill_validity_duration(self, x509_obj):
+        """Set not_valid_before/not_valid_after from an X.509 object."""
+        # Certificate expires after "not_valid_after"
+        # Certificate is invalid before "not_valid_before"
+        if self.data_type == DataType.CRT:
+            self.not_valid_after = x509_obj.not_valid_after
+            self.not_valid_before = x509_obj.not_valid_before
+        # CRL is invalid before "last_update", expires after "next_update";
+        # "next_update" is optional (None): treat such a CRL as never expiring.
+        elif self.data_type == DataType.CRL:
+            self.not_valid_after = x509_obj.next_update or datetime.datetime.max
+            self.not_valid_before = x509_obj.last_update
+        # CertificateSigningRequest is always valid.
+        elif self.data_type == DataType.CSR:
+            self.not_valid_after = datetime.datetime.max
+            self.not_valid_before = datetime.datetime.min
+        else:
+            raise ValueError("Unsupported file_type: {}".format(self.data_type))
+
+
+class X509Parser:
+    """A parser class to parse crt/crl/csr file or data in PEM/DER format."""
+    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n(?P<data>.*?)-{5}END (?P=type)-{5}\n'
+    PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\r?\n'
+    PEM_TAGS = {
+        DataType.CRT: 'CERTIFICATE',
+        DataType.CRL: 'X509 CRL',
+        DataType.CSR: 'CERTIFICATE REQUEST'
+    }
+
+    def __init__(self,
+                 backends:
+                 typing.Dict[DataType,
+                             typing.Dict[DataFormat,
+                                         typing.Callable[[bytes], object]]]) \
+    -> None:
+        self.backends = backends
+        self.__generate_parsers()
+
+    def __generate_parser(self, data_type: DataType):
+        """Return a parser: bytes -> X.509 object of data_type, or None."""
+        tag = self.PEM_TAGS[data_type]
+        pem_loader = self.backends[data_type][DataFormat.PEM]
+        der_loader = self.backends[data_type][DataFormat.DER]
+        def wrapper(data: bytes):
+            pem_type = X509Parser.pem_data_type(data)
+            # It is in PEM format with target tag
+            if pem_type == tag:
+                return pem_loader(data)
+            # PEM with another tag: not an object of this DataType
+            if pem_type:
+                return None
+            # No PEM header: try DER, treating a ValueError as "not this type"
+            try:
+                result = der_loader(data)
+            except ValueError:
+                result = None
+            return result
+        wrapper.__name__ = "{}.parser[{}]".format(type(self).__name__, tag)
+        return wrapper
+
+    def __generate_parsers(self):
+        """Generate parsers for all supported DataTypes."""
+        self.parsers = {}
+        for data_type in self.PEM_TAGS:
+            self.parsers[data_type] = self.__generate_parser(data_type)
+
+    def __getitem__(self, item):
+        return self.parsers[item]
+
+    @staticmethod
+    def pem_data_type(data: bytes) -> typing.Optional[str]:
+        """Get the tag from the data in PEM format
+
+        :param data: data to be checked in binary mode.
+        :return: PEM tag or None when no tag detected.
+        """
+        m = re.search(X509Parser.PEM_TAG_REGEX, data)
+        if m is not None:
+            return m.group('type').decode('UTF-8')
+        else:
+            return None
+
+    @staticmethod
+    def check_hex_string(hex_str: str) -> bool:
+        """Return True if hex_str looks like exactly one DER SEQUENCE."""
+        hex_len = len(hex_str)
+        # At least 6 hex char for 3 bytes: Type + Length + Content
+        if hex_len < 6:
+            return False
+        # Check if Type (1 byte) is SEQUENCE.
+        if hex_str[0:2] != '30':
+            return False
+        # First LENGTH octet: short form (< 0x80) or 0x80 | number of octets
+        content_len = int(hex_str[2:4], base=16)
+        consumed = 4
+        if content_len in (128, 255):
+            # 0x80 = indefinite length, 0xFF = reserved: neither is DER
+            return False
+        elif content_len > 127:
+            # Definite long form: the next (content_len - 128) octets
+            length_len = (content_len - 128) * 2
+            content_len = int(hex_str[consumed:consumed+length_len], base=16)
+            consumed += length_len
+        # The string must hold exactly one TLV: header plus content octets
+        if hex_len != content_len * 2 + consumed:
+            return False
+        return True
+
+
+class Auditor:
+    """
+    A base class that uses X509Parser to parse files to a list of AuditData.
+
+    A subclass must implement the following methods:
+      - collect_default_files: Return a list of file names that are used by
+        default for parsing (auditing). The list will be stored in
+        Auditor.default_files.
+      - parse_file: Method that parses a single file to a list of AuditData.
+
+    A subclass may override the following methods:
+      - parse_bytes: By default, it parses `bytes` that contain only one valid
+        X.509 object (DER/PEM format) and returns it as AuditData, or None.
+      - walk_all: By default, it iterates over all the files in the provided
+        file name list, calls `parse_file` for each file and stores the results
+        by extending Auditor.audit_data.
+    """
+    def __init__(self, logger):
+        self.logger = logger
+        self.default_files = self.collect_default_files()
+        # A list to store the parsed audit_data.
+        self.audit_data = [] # type: typing.List[AuditData]
+        self.parser = X509Parser({
+            DataType.CRT: {
+                DataFormat.PEM: x509.load_pem_x509_certificate,
+                DataFormat.DER: x509.load_der_x509_certificate
+            },
+            DataType.CRL: {
+                DataFormat.PEM: x509.load_pem_x509_crl,
+                DataFormat.DER: x509.load_der_x509_crl
+            },
+            DataType.CSR: {
+                DataFormat.PEM: x509.load_pem_x509_csr,
+                DataFormat.DER: x509.load_der_x509_csr
+            },
+        })
+
+    def collect_default_files(self) -> typing.List[str]:
+        """Collect the default files for parsing."""
+        raise NotImplementedError
+
+    def parse_file(self, filename: str) -> typing.List[AuditData]:
+        """
+        Parse a list of AuditData from file.
+
+        :param filename: name of the file to parse.
+        :return list of AuditData parsed from the file.
+        """
+        raise NotImplementedError
+
+    def parse_bytes(self, data: bytes):
+        """Return AuditData for the first DataType that parses data, or None."""
+        for data_type in DataType:
+            try:
+                result = self.parser[data_type](data)
+            except ValueError as val_error:
+                result = None
+                self.logger.warning(val_error)
+            if result is not None:
+                audit_data = AuditData(data_type, result)
+                return audit_data
+        return None
+
+    def walk_all(self, file_list: typing.Optional[typing.List[str]] = None):
+        """
+        Parse every file in file_list (or default_files) into audit_data.
+        """
+        if file_list is None:
+            file_list = self.default_files
+        for filename in file_list:
+            data_list = self.parse_file(filename)
+            self.audit_data.extend(data_list)
+
+    @staticmethod
+    def find_test_dir():
+        """Get the relative path for the MbedTLS test directory."""
+        return os.path.relpath(os.path.join(build_tree.guess_mbedtls_root(), 'tests'))
+
+
+class TestDataAuditor(Auditor):
+    """Auditor for every regular file under `tests/data_files/`."""
+
+    def collect_default_files(self):
+        """
+        Return all regular files found recursively in `tests/data_files/`.
+        """
+        pattern = os.path.join(self.find_test_dir(), 'data_files/**')
+        candidates = glob.glob(pattern, recursive=True)
+        return [path for path in candidates if os.path.isfile(path)]
+
+    def parse_file(self, filename: str) -> typing.List[AuditData]:
+        """
+        Parse the whole file as a single X.509 object.
+
+        :param filename: name of the file to parse.
+        :return: a one-element list holding the file's AuditData, or an
+            empty list when the file is not a parsable crt/crl/csr.
+        """
+        with open(filename, 'rb') as handle:
+            raw = handle.read()
+        audit_data = self.parse_bytes(raw)
+        if audit_data is None:
+            return []
+        audit_data.location = filename
+        return [audit_data]
+
+
+def parse_suite_data(data_f):
+    """
+    Parse a .data file for test arguments that possibly contain
+    valid X.509 data. If you need a more precise parser, please
+    use generate_test_code.parse_test_data instead.
+
+    :param data_f: iterable of text lines, e.g. a FileWrapper.
+    :return: Generator that yields test function argument list.
+    """
+    for line in data_f:
+        line = line.strip()
+        # Skip comments
+        if line.startswith('#'):
+            continue
+
+        # Function line: identifier, optional ':'-args, then a '"' argument
+        match = re.search(r'\A\w+(.*:)?\"', line)
+        if match:
+            # Split on ':' not escaped by '\', drop empties and the name
+            parts = re.split(r'(?<!\\):', line)
+            parts = [x for x in parts if x]
+            args = parts[1:]
+            yield args
+
+
+class SuiteDataAuditor(Auditor):
+    """Class for auditing files in `tests/suites/*.data`"""
+
+    def collect_default_files(self):
+        """Collect all files in `tests/suites/*.data`"""
+        test_dir = self.find_test_dir()
+        suites_data_folder = os.path.join(test_dir, 'suites')
+        data_files = glob.glob(os.path.join(suites_data_folder, '*.data'))
+        return data_files
+
+    def parse_file(self, filename: str):
+        """
+        Parse a list of AuditData from test suite data file.
+
+        :param filename: name of the file to parse.
+        :return: list of AuditData, located as "<file>:<line>:#<arg index>".
+        """
+        audit_data_list = []
+        with FileWrapper(filename) as data_f:
+            for test_args in parse_suite_data(data_f):
+                for idx, test_arg in enumerate(test_args):
+                    match = re.match(r'"(?P<data>[0-9a-fA-F]+)"', test_arg)
+                    if not match:
+                        continue
+                    if not X509Parser.check_hex_string(match.group('data')):
+                        continue
+                    audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
+                    if audit_data is None:
+                        continue
+                    audit_data.location = "{}:{}:#{}".format(filename,
+                                                             data_f.line_no,
+                                                             idx + 1)
+                    audit_data_list.append(audit_data)
+
+        return audit_data_list
+
+
+def list_all(audit_data: AuditData):
+    """Print validity start, validity end, type and location, tab-separated."""
+    print("\t".join((audit_data.not_valid_before.isoformat(timespec='seconds'),
+                     audit_data.not_valid_after.isoformat(timespec='seconds'),
+                     audit_data.data_type.name,
+                     audit_data.location)))
+
+
+def configure_logger(logger: logging.Logger) -> None:
+    """
+    Configure the logging.Logger instance so that:
+        - Format is set to "[%(levelname)s]: %(message)s".
+        - loglevel >= WARNING are printed to stderr.
+        - loglevel <  WARNING are printed to stdout.
+    """
+    class MaxLevelFilter(logging.Filter):
+        # pylint: disable=too-few-public-methods
+        def __init__(self, max_level, name=''):
+            super().__init__(name)
+            self.max_level = max_level
+
+        def filter(self, record: logging.LogRecord) -> bool:
+            return record.levelno <= self.max_level
+
+    log_formatter = logging.Formatter("[%(levelname)s]: %(message)s")
+
+    # set loglevel >= WARNING to be printed to stderr
+    stderr_hdlr = logging.StreamHandler(sys.stderr)
+    stderr_hdlr.setLevel(logging.WARNING)
+    stderr_hdlr.setFormatter(log_formatter)
+
+    # set loglevel < WARNING (incl. custom levels above INFO) to stdout
+    stdout_hdlr = logging.StreamHandler(sys.stdout)
+    stdout_hdlr.addFilter(MaxLevelFilter(logging.WARNING - 1))
+    stdout_hdlr.setFormatter(log_formatter)
+
+    logger.addHandler(stderr_hdlr)
+    logger.addHandler(stdout_hdlr)
+
+
+def main():
+    """
+    Parse arguments, audit the selected files and print the results.
+    """
+    parser = argparse.ArgumentParser(description=__doc__)
+
+    parser.add_argument('-a', '--all',
+                        action='store_true',
+                        help='list the information of all the files')
+    parser.add_argument('-v', '--verbose',
+                        action='store_true', dest='verbose',
+                        help='show logs')
+    parser.add_argument('--from', dest='start_date',
+                        help=('Start of desired validity period (UTC, YYYY-MM-DD). '
+                              'Default: today'),
+                        metavar='DATE')
+    parser.add_argument('--to', dest='end_date',
+                        help=('End of desired validity period (UTC, YYYY-MM-DD). '
+                              'Default: --from'),
+                        metavar='DATE')
+    parser.add_argument('--data-files', action='append', nargs='*',
+                        help='data files to audit',
+                        metavar='FILE')
+    parser.add_argument('--suite-data-files', action='append', nargs='*',
+                        help='suite data files to audit',
+                        metavar='FILE')
+
+    args = parser.parse_args()
+
+    # start main routine
+    # setup logger
+    logger = logging.getLogger()
+    configure_logger(logger)
+    logger.setLevel(logging.DEBUG if args.verbose else logging.ERROR)
+
+    td_auditor = TestDataAuditor(logger)
+    sd_auditor = SuiteDataAuditor(logger)
+
+    data_files = []
+    suite_data_files = []
+    if args.data_files is None and args.suite_data_files is None:
+        data_files = td_auditor.default_files
+        suite_data_files = sd_auditor.default_files
+    else:
+        if args.data_files is not None:
+            data_files = [x for l in args.data_files for x in l]
+        if args.suite_data_files is not None:
+            suite_data_files = [x for l in args.suite_data_files for x in l]
+
+    # validity period start date (naive UTC, like cryptography's datetimes)
+    if args.start_date:
+        start_date = datetime.datetime.fromisoformat(args.start_date)
+    else:
+        start_date = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+    # validity period end date
+    if args.end_date:
+        end_date = datetime.datetime.fromisoformat(args.end_date)
+    else:
+        end_date = start_date
+
+    # go through all the files
+    td_auditor.walk_all(data_files)
+    sd_auditor.walk_all(suite_data_files)
+    audit_results = td_auditor.audit_data + sd_auditor.audit_data
+
+    # Keep only objects that are NOT valid throughout the whole period
+    # [start_date, end_date]; --all disables this filter below.
+    filter_func = lambda d: (start_date < d.not_valid_before) or \
+                            (d.not_valid_after < end_date)
+
+    if args.all:
+        filter_func = None
+
+    # filter and output the results
+    for d in filter(filter_func, audit_results):
+        list_all(d)
+
+    logger.debug("Done!")
+
+check_cryptography_version() # Fail early, even on import, if too old.
+if __name__ == "__main__":
+    main()
diff --git a/tests/scripts/generate_test_code.py b/tests/scripts/generate_test_code.py
index 839fccd..ff7f9b9 100755
--- a/tests/scripts/generate_test_code.py
+++ b/tests/scripts/generate_test_code.py
@@ -163,7 +163,6 @@
 """
 
 
-import io
 import os
 import re
 import sys
@@ -227,43 +226,57 @@
     pass
 
 
-class FileWrapper(io.FileIO):
+class FileWrapper:
     """
-    This class extends built-in io.FileIO class with attribute line_no,
+    This class wraps a binary file object and adds attribute line_no,
     that indicates line number for the line that is read.
     """
 
-    def __init__(self, file_name):
+    def __init__(self, file_name) -> None:
         """
-        Instantiate the base class and initialize the line number to 0.
+        Instantiate the file object and initialize the line number to 0.
 
         :param file_name: File path to open.
         """
-        super().__init__(file_name, 'r')
+        # underlying binary file object (wrapped, not subclassed)
+        self._f = open(file_name, 'rb')
         self._line_no = 0
 
+    def __iter__(self):
+        return self
+
     def __next__(self):
         """
-        This method overrides base class's __next__ method and extends it
-        method to count the line numbers as each line is read.
+        Return the next decoded line; this makes FileWrapper iterable.
+        It counts the line numbers as each line is read.
 
         :return: Line read from file.
         """
-        line = super().__next__()
-        if line is not None:
-            self._line_no += 1
-            # Convert byte array to string with correct encoding and
-            # strip any whitespaces added in the decoding process.
-            return line.decode(sys.getdefaultencoding()).rstrip() + '\n'
-        return None
+        line = self._f.__next__()
+        self._line_no += 1
+        # Decode bytes to str, then drop trailing whitespace (including
+        # '\r\n' or '\n') and re-add exactly one '\n'.
+        return line.decode(sys.getdefaultencoding()).rstrip()+ '\n'
 
-    def get_line_no(self):
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._f.__exit__(exc_type, exc_val, exc_tb)
+
+    @property
+    def line_no(self):
         """
-        Gives current line number.
+        Property that indicates line number for the line that is read.
         """
         return self._line_no
 
-    line_no = property(get_line_no)
+    @property
+    def name(self):
+        """
+        Property that indicates name of the file that is read.
+        """
+        return self._f.name
 
 
 def split_dep(dep):