cert_audit: Merge audit_data for identical X.509 objects

Signed-off-by: Pengyu Lv <pengyu.lv@arm.com>
This commit is contained in:
Pengyu Lv 2023-04-28 10:58:38 +08:00
parent e245c0c734
commit fe13bd3d0e

View file

@ -65,8 +65,13 @@ class AuditData:
#pylint: disable=too-few-public-methods
def __init__(self, data_type: DataType, x509_obj):
self.data_type = data_type
self.location = ""
# the locations that the x509 object could be found
self.locations = [] # type: typing.List[str]
self.fill_validity_duration(x509_obj)
self._obj = x509_obj
def __eq__(self, __value) -> bool:
return self._obj == __value._obj
def fill_validity_duration(self, x509_obj):
"""Read validity period from an X.509 object."""
@ -282,7 +287,7 @@ class TestDataAuditor(Auditor):
for idx, m in enumerate(re.finditer(X509Parser.PEM_REGEX, data, flags=re.S), 1):
result = self.parse_bytes(data[m.start():m.end()])
if result is not None:
result.location = "{}#{}".format(filename, idx)
result.locations.append("{}#{}".format(filename, idx))
results.append(result)
return results
@ -342,20 +347,38 @@ class SuiteDataAuditor(Auditor):
audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
if audit_data is None:
continue
audit_data.location = "{}:{}:#{}".format(filename,
data_f.line_no,
idx + 1)
audit_data.locations.append("{}:{}:#{}".format(filename,
data_f.line_no,
idx + 1))
audit_data_list.append(audit_data)
return audit_data_list
def merge_auditdata(original: typing.List[AuditData]) \
-> typing.List[AuditData]:
"""
Multiple AuditData might be extracted from different locations for
an identical X.509 object. Merge them into one entry in the list.
"""
results = []
for x in original:
if x not in results:
results.append(x)
else:
idx = results.index(x)
results[idx].locations.extend(x.locations)
return results
def list_all(audit_data: AuditData):
print("{}\t{}\t{}\t{}".format(
print("{:20}\t{:20}\t{:3}\t{}".format(
audit_data.not_valid_before.isoformat(timespec='seconds'),
audit_data.not_valid_after.isoformat(timespec='seconds'),
audit_data.data_type.name,
audit_data.location))
audit_data.locations[0]))
for loc in audit_data.locations[1:]:
print("{:20}\t{:20}\t{:3}\t{}".format('', '', '', loc))
def configure_logger(logger: logging.Logger) -> None:
@ -455,6 +478,10 @@ def main():
sd_auditor.walk_all(suite_data_files)
audit_results = td_auditor.audit_data + sd_auditor.audit_data
audit_results = merge_auditdata(audit_results)
logger.info("Total: {} objects found!".format(len(audit_results)))
# we filter out the files whose validity duration covers the provided
# duration.
filter_func = lambda d: (start_date < d.not_valid_before) or \