Merge pull request #7512 from lpy4105/issue/7014/cert_audit-improvement
cert_audit: Improvements to the audit script
diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py
index 1ccfc21..5506e40 100755
--- a/tests/scripts/audit-validity-dates.py
+++ b/tests/scripts/audit-validity-dates.py
@@ -31,6 +31,7 @@
import datetime
import glob
import logging
+import hashlib
from enum import Enum
# The script requires cryptography >= 35.0.0 which is only available
@@ -45,7 +46,7 @@
def check_cryptography_version():
match = re.match(r'^[0-9]+', cryptography.__version__)
- if match is None or int(match[0]) < 35:
+ if match is None or int(match.group(0)) < 35:
raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
+ " ({} is too old)".format(cryptography.__version__))
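Note on the `match[0]` -> `match.group(0)` substitution: both return the whole match, but subscripting a `re.Match` only works on Python >= 3.6 and is flagged by some type checkers, so `.group(0)` is the portable spelling. That rationale is an inference; the diff only shows the substitution. A quick illustration (the version string is made up):

    import re

    match = re.match(r'^[0-9]+', '41.0.7')
    # .group(0) is the whole match; equivalent to match[0] on Python >= 3.6.
    assert match is not None and match.group(0) == '41'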
@@ -65,8 +66,20 @@
#pylint: disable=too-few-public-methods
def __init__(self, data_type: DataType, x509_obj):
self.data_type = data_type
- self.location = ""
+ # the locations where the x509 object can be found
+ self.locations = [] # type: typing.List[str]
self.fill_validity_duration(x509_obj)
+ self._obj = x509_obj
+ encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
+ self._identifier = hashlib.sha1(self._obj.public_bytes(encoding)).hexdigest()
+
+ @property
+ def identifier(self):
+ """
+ Identifier of the underlying X.509 object, which is consistent across
+ different runs.
+ """
+ return self._identifier
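The identifier is the SHA-1 hex digest of the object's DER encoding, so the same certificate parsed from different files (or from a test suite hex dump) maps to the same key on every run. A minimal standalone sketch of the idea, assuming a certificate loaded with the `cryptography` package (`stable_identifier` is a hypothetical helper, not part of the script):

    import hashlib

    from cryptography import x509
    from cryptography.hazmat.primitives.serialization import Encoding

    def stable_identifier(cert: x509.Certificate) -> str:
        # DER is a canonical encoding, so two copies of the same certificate
        # (one loaded from PEM, one from DER) hash to the same hex digest.
        return hashlib.sha1(cert.public_bytes(Encoding.DER)).hexdigest()

    # Usage sketch (cert_pem_bytes / cert_der_bytes are placeholders):
    # cert_a = x509.load_pem_x509_certificate(cert_pem_bytes)
    # cert_b = x509.load_der_x509_certificate(cert_der_bytes)
    # assert stable_identifier(cert_a) == stable_identifier(cert_b)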
def fill_validity_duration(self, x509_obj):
"""Read validity period from an X.509 object."""
@@ -90,7 +103,7 @@
class X509Parser:
"""A parser class to parse crt/crl/csr file or data in PEM/DER format."""
- PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n(?P<data>.*?)-{5}END (?P=type)-{5}\n'
+ PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'
PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n'
PEM_TAGS = {
DataType.CRT: 'CERTIFICATE',
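Dropping the trailing `\n` anchors lets the pattern match PEM blocks that are not newline-terminated or that are embedded in surrounding content; the rewritten `parse_file` below compensates by passing `re.S` so the non-greedy `.*?` can span line breaks. A small check of that behaviour (the sample data is made up):

    import re

    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'

    # Two PEM blocks in one buffer; the second one has no trailing newline.
    data = (b'-----BEGIN CERTIFICATE-----\nAAAA\n-----END CERTIFICATE-----\n'
            b'-----BEGIN CERTIFICATE-----\nBBBB\n-----END CERTIFICATE-----')

    # re.S lets '.*?' cross the newlines inside each block.
    matches = list(re.finditer(PEM_REGEX, data, flags=re.S))
    assert len(matches) == 2
    assert matches[1].group('data').strip() == b'BBBB'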
@@ -193,13 +206,11 @@
X.509 data (DER/PEM format) to an X.509 object.
- walk_all: By default, it iterates over all the files in the provided
file name list, calls `parse_file` for each file and stores the results
- by extending Auditor.audit_data.
+ by merging them into the `results` dictionary passed to the function.
"""
def __init__(self, logger):
self.logger = logger
self.default_files = self.collect_default_files()
- # A list to store the parsed audit_data.
- self.audit_data = [] # type: typing.List[AuditData]
self.parser = X509Parser({
DataType.CRT: {
DataFormat.PEM: x509.load_pem_x509_certificate,
@@ -241,15 +252,27 @@
return audit_data
return None
- def walk_all(self, file_list: typing.Optional[typing.List[str]] = None):
+ def walk_all(self,
+ results: typing.Dict[str, AuditData],
+ file_list: typing.Optional[typing.List[str]] = None) \
+ -> None:
"""
- Iterate over all the files in the list and get audit data.
+ Iterate over all the files in the list and get audit data. The
+ results are merged into the `results` dictionary passed to this function.
+
+ :param results: The dictionary used to store the parsed
+ AuditData. The keys of this dictionary should
+ be the identifiers of the AuditData objects.
"""
if file_list is None:
file_list = self.default_files
for filename in file_list:
data_list = self.parse_file(filename)
- self.audit_data.extend(data_list)
+ for d in data_list:
+ if d.identifier in results:
+ results[d.identifier].locations.extend(d.locations)
+ else:
+ results[d.identifier] = d
@staticmethod
def find_test_dir():
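With a shared `results` dictionary keyed by identifier, an object found in several places collapses into a single entry whose `locations` list accumulates every occurrence. A self-contained sketch of that merge rule (`FakeAuditData` and the location strings are illustrative stand-ins, not the script's API):

    import typing

    class FakeAuditData:
        # Stand-in for AuditData, carrying only what the merge needs.
        def __init__(self, identifier: str, locations: typing.List[str]):
            self.identifier = identifier
            self.locations = locations

    def merge_into(results: typing.Dict[str, FakeAuditData],
                   found: typing.List[FakeAuditData]) -> None:
        # Same rule as walk_all: one entry per identifier, locations appended.
        for d in found:
            if d.identifier in results:
                results[d.identifier].locations.extend(d.locations)
            else:
                results[d.identifier] = d

    results = {}  # type: typing.Dict[str, FakeAuditData]
    merge_into(results, [FakeAuditData('abc', ['server1.crt#1'])])
    merge_into(results, [FakeAuditData('abc', ['x509parse.data:42:#1'])])
    assert results['abc'].locations == ['server1.crt#1',
                                        'x509parse.data:42:#1']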
@@ -277,12 +300,25 @@
"""
with open(filename, 'rb') as f:
data = f.read()
- result = self.parse_bytes(data)
- if result is not None:
- result.location = filename
- return [result]
- else:
- return []
+
+ results = []
+ # Try to parse all PEM blocks.
+ is_pem = False
+ for idx, m in enumerate(re.finditer(X509Parser.PEM_REGEX, data, flags=re.S), 1):
+ is_pem = True
+ result = self.parse_bytes(data[m.start():m.end()])
+ if result is not None:
+ result.locations.append("{}#{}".format(filename, idx))
+ results.append(result)
+
+ # Might be DER format.
+ if not is_pem:
+ result = self.parse_bytes(data)
+ if result is not None:
+ result.locations.append("{}".format(filename))
+ results.append(result)
+
+ return results
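`parse_file` now returns one entry per PEM block, with locations of the form `file#N`, and falls back to treating the whole file as a DER candidate when no PEM marker matches. A standalone sketch of that control flow (`split_pem_blocks` is an illustrative helper, not the script's API):

    import re
    import typing

    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'

    def split_pem_blocks(data: bytes, filename: str) \
            -> typing.List[typing.Tuple[str, bytes]]:
        # One (location, blob) pair per PEM block found in the buffer.
        blocks = [('{}#{}'.format(filename, idx), data[m.start():m.end()])
                  for idx, m in enumerate(re.finditer(PEM_REGEX, data, re.S), 1)]
        # No PEM block at all: hand the whole buffer over as DER.
        return blocks or [(filename, data)]

    pem = b'-----BEGIN CERTIFICATE-----\nAAAA\n-----END CERTIFICATE-----\n'
    assert split_pem_blocks(pem * 2, 'bundle.pem')[1][0] == 'bundle.pem#2'
    assert split_pem_blocks(b'\x30\x82', 'cert.der') == [('cert.der', b'\x30\x82')]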
def parse_suite_data(data_f):
@@ -339,20 +375,22 @@
audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
if audit_data is None:
continue
- audit_data.location = "{}:{}:#{}".format(filename,
- data_f.line_no,
- idx + 1)
+ audit_data.locations.append("{}:{}:#{}".format(filename,
+ data_f.line_no,
+ idx + 1))
audit_data_list.append(audit_data)
return audit_data_list
def list_all(audit_data: AuditData):
- print("{}\t{}\t{}\t{}".format(
- audit_data.not_valid_before.isoformat(timespec='seconds'),
- audit_data.not_valid_after.isoformat(timespec='seconds'),
- audit_data.data_type.name,
- audit_data.location))
+ for loc in audit_data.locations:
+ print("{}\t{:20}\t{:20}\t{:3}\t{}".format(
+ audit_data.identifier,
+ audit_data.not_valid_before.isoformat(timespec='seconds'),
+ audit_data.not_valid_after.isoformat(timespec='seconds'),
+ audit_data.data_type.name,
+ loc))
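Each location now prints as its own tab-separated row prefixed by the shared identifier, so duplicates of one object group together visually. Illustrative output only; the digest, dates, and paths below are made up:

    160e...90ab	2019-02-10T14:44:06 	2029-02-10T14:44:06 	CRT	tests/data_files/server1.crt#1
    160e...90ab	2019-02-10T14:44:06 	2029-02-10T14:44:06 	CRT	tests/suites/test_suite_x509parse.data:42:#1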
def configure_logger(logger: logging.Logger) -> None:
@@ -448,20 +486,24 @@
end_date = start_date
# go through all the files
- td_auditor.walk_all(data_files)
- sd_auditor.walk_all(suite_data_files)
- audit_results = td_auditor.audit_data + sd_auditor.audit_data
+ audit_results = {}
+ td_auditor.walk_all(audit_results, data_files)
+ sd_auditor.walk_all(audit_results, suite_data_files)
+
+ logger.info("Total: {} objects found!".format(len(audit_results)))
# we filter out the objects whose validity duration covers the provided
# duration.
filter_func = lambda d: (start_date < d.not_valid_before) or \
(d.not_valid_after < end_date)
+ sortby_end = lambda d: d.not_valid_after
+
if args.all:
filter_func = None
# filter and output the results
- for d in filter(filter_func, audit_results):
+ for d in sorted(filter(filter_func, audit_results.values()), key=sortby_end):
list_all(d)
logger.debug("Done!")
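The predicate marks an object for reporting when its validity window fails to cover [start_date, end_date], and sorting on `not_valid_after` lists the soonest-to-expire objects first. A self-contained check of both behaviours (the dates and the `Window` stand-in are made up):

    import datetime

    start_date = datetime.datetime(2023, 1, 1)
    end_date = start_date

    # True (reported) when the validity window does NOT cover the duration.
    filter_func = lambda d: (start_date < d.not_valid_before) or \
                            (d.not_valid_after < end_date)
    sortby_end = lambda d: d.not_valid_after

    class Window:
        # Minimal stand-in for AuditData's validity fields.
        def __init__(self, nvb, nva):
            self.not_valid_before, self.not_valid_after = nvb, nva

    expired = Window(datetime.datetime(2000, 1, 1),
                     datetime.datetime(2010, 1, 1))
    valid = Window(datetime.datetime(2000, 1, 1),
                   datetime.datetime(2100, 1, 1))
    assert filter_func(expired) and not filter_func(valid)
    # Soonest-expiring objects come first in the report.
    assert sorted([valid, expired], key=sortby_end)[0] is expired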