Merge pull request #7512 from lpy4105/issue/7014/cert_audit-improvement

cert_audit: Improvements of audit script
diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py
index 1ccfc21..5506e40 100755
--- a/tests/scripts/audit-validity-dates.py
+++ b/tests/scripts/audit-validity-dates.py
@@ -31,6 +31,7 @@
 import datetime
 import glob
 import logging
+import hashlib
 from enum import Enum
 
 # The script requires cryptography >= 35.0.0 which is only available
@@ -45,7 +46,7 @@
 
 def check_cryptography_version():
     match = re.match(r'^[0-9]+', cryptography.__version__)
-    if match is None or int(match[0]) < 35:
+    if match is None or int(match.group(0)) < 35:
         raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
                         + "({} is too old)".format(cryptography.__version__))
 
@@ -65,8 +66,20 @@
     #pylint: disable=too-few-public-methods
     def __init__(self, data_type: DataType, x509_obj):
         self.data_type = data_type
-        self.location = ""
+        # the locations where the x509 object can be found
+        self.locations = [] # type: typing.List[str]
         self.fill_validity_duration(x509_obj)
+        self._obj = x509_obj
+        encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
+        self._identifier = hashlib.sha1(self._obj.public_bytes(encoding)).hexdigest()
+
+    @property
+    def identifier(self):
+        """
+        Identifier of the underlying X.509 object (hex SHA-1 digest of its DER
+        encoding), which is consistent across different runs.
+        """
+        return self._identifier
 
     def fill_validity_duration(self, x509_obj):
         """Read validity period from an X.509 object."""
@@ -90,7 +103,7 @@
 
 class X509Parser:
     """A parser class to parse crt/crl/csr file or data in PEM/DER format."""
-    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n(?P<data>.*?)-{5}END (?P=type)-{5}\n'
+    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'
     PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n'
     PEM_TAGS = {
         DataType.CRT: 'CERTIFICATE',
@@ -193,13 +206,11 @@
         X.509 data(DER/PEM format) to an X.509 object.
       - walk_all: Defaultly, it iterates over all the files in the provided
         file name list, calls `parse_file` for each file and stores the results
-        by extending Auditor.audit_data.
+        by updating the `results` dictionary passed to the function.
     """
     def __init__(self, logger):
         self.logger = logger
         self.default_files = self.collect_default_files()
-        # A list to store the parsed audit_data.
-        self.audit_data = [] # type: typing.List[AuditData]
         self.parser = X509Parser({
             DataType.CRT: {
                 DataFormat.PEM: x509.load_pem_x509_certificate,
@@ -241,15 +252,27 @@
                 return audit_data
         return None
 
-    def walk_all(self, file_list: typing.Optional[typing.List[str]] = None):
+    def walk_all(self,
+                 results: typing.Dict[str, AuditData],
+                 file_list: typing.Optional[typing.List[str]] = None) \
+        -> None:
         """
-        Iterate over all the files in the list and get audit data.
+        Iterate over all the files in the list and get audit data. The
+        results are written to the `results` dictionary passed to this function.
+
+        :param results: The dictionary used to store the parsed
+                        AuditData. Each key of this dictionary should
+                        be the identifier of the corresponding AuditData.
         """
         if file_list is None:
             file_list = self.default_files
         for filename in file_list:
             data_list = self.parse_file(filename)
-            self.audit_data.extend(data_list)
+            for d in data_list:
+                if d.identifier in results:
+                    results[d.identifier].locations.extend(d.locations)
+                else:
+                    results[d.identifier] = d
 
     @staticmethod
     def find_test_dir():
@@ -277,12 +300,25 @@
         """
         with open(filename, 'rb') as f:
             data = f.read()
-        result = self.parse_bytes(data)
-        if result is not None:
-            result.location = filename
-            return [result]
-        else:
-            return []
+
+        results = []
+        # Try to parse all PEM blocks.
+        is_pem = False
+        for idx, m in enumerate(re.finditer(X509Parser.PEM_REGEX, data, flags=re.S), 1):
+            is_pem = True
+            result = self.parse_bytes(data[m.start():m.end()])
+            if result is not None:
+                result.locations.append("{}#{}".format(filename, idx))
+                results.append(result)
+
+        # No PEM block was found, so the file might be in DER format.
+        if not is_pem:
+            result = self.parse_bytes(data)
+            if result is not None:
+                result.locations.append("{}".format(filename))
+                results.append(result)
+
+        return results
 
 
 def parse_suite_data(data_f):
@@ -339,20 +375,22 @@
                 audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
                 if audit_data is None:
                     continue
-                audit_data.location = "{}:{}:#{}".format(filename,
-                                                         data_f.line_no,
-                                                         idx + 1)
+                audit_data.locations.append("{}:{}:#{}".format(filename,
+                                                               data_f.line_no,
+                                                               idx + 1))
                 audit_data_list.append(audit_data)
 
         return audit_data_list
 
 
 def list_all(audit_data: AuditData):
-    print("{}\t{}\t{}\t{}".format(
-        audit_data.not_valid_before.isoformat(timespec='seconds'),
-        audit_data.not_valid_after.isoformat(timespec='seconds'),
-        audit_data.data_type.name,
-        audit_data.location))
+    for loc in audit_data.locations:
+        print("{}\t{:20}\t{:20}\t{:3}\t{}".format(
+            audit_data.identifier,
+            audit_data.not_valid_before.isoformat(timespec='seconds'),
+            audit_data.not_valid_after.isoformat(timespec='seconds'),
+            audit_data.data_type.name,
+            loc))
 
 
 def configure_logger(logger: logging.Logger) -> None:
@@ -448,20 +486,24 @@
         end_date = start_date
 
     # go through all the files
-    td_auditor.walk_all(data_files)
-    sd_auditor.walk_all(suite_data_files)
-    audit_results = td_auditor.audit_data + sd_auditor.audit_data
+    audit_results = {}
+    td_auditor.walk_all(audit_results, data_files)
+    sd_auditor.walk_all(audit_results, suite_data_files)
+
+    logger.info("Total: {} objects found!".format(len(audit_results)))
 
     # we filter out the files whose validity duration covers the provided
     # duration.
     filter_func = lambda d: (start_date < d.not_valid_before) or \
                             (d.not_valid_after < end_date)
 
+    sortby_end = lambda d: d.not_valid_after
+
     if args.all:
         filter_func = None
 
     # filter and output the results
-    for d in filter(filter_func, audit_results):
+    for d in sorted(filter(filter_func, audit_results.values()), key=sortby_end):
         list_all(d)
 
     logger.debug("Done!")