Merge pull request #7399 from lpy4105/issue/7014/certificate-audit-script

cert_audit: Add test certificate date audit script
This commit is contained in:
Bence Szépkúti 2023-05-09 13:10:01 +02:00 committed by GitHub
commit ddfd0a27df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 508 additions and 18 deletions

View file

@ -10,3 +10,9 @@ pylint == 2.4.4
# Use the earliest version of mypy that works with our code base.
# See https://github.com/Mbed-TLS/mbedtls/pull/3953 .
mypy >= 0.780
# Install cryptography to avoid import-error reported by pylint.
# What we really need is cryptography >= 35.0.0, which is only
# available for Python >= 3.6.
cryptography >= 35.0.0; sys_platform == 'linux' and python_version >= '3.6'
cryptography; sys_platform == 'linux' and python_version < '3.6'

View file

@ -0,0 +1,471 @@
#!/usr/bin/env python3
#
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Audit validity date of X509 crt/crl/csr.
This script is used to audit the validity date of crt/crl/csr used for testing.
It prints the information about X.509 objects excluding the objects that
are valid throughout the desired validity period. The data are collected
from tests/data_files/ and tests/suites/*.data files by default.
"""
import os
import sys
import re
import typing
import argparse
import datetime
import glob
import logging
from enum import Enum
# The script requires cryptography >= 35.0.0 which is only available
# for Python >= 3.6.
import cryptography
from cryptography import x509
from generate_test_code import FileWrapper
import scripts_path # pylint: disable=unused-import
from mbedtls_dev import build_tree
def check_cryptography_version():
match = re.match(r'^[0-9]+', cryptography.__version__)
if match is None or int(match[0]) < 35:
raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
+ "({} is too old)".format(cryptography.__version__))
class DataType(Enum):
CRT = 1 # Certificate
CRL = 2 # Certificate Revocation List
CSR = 3 # Certificate Signing Request
class DataFormat(Enum):
PEM = 1 # Privacy-Enhanced Mail
DER = 2 # Distinguished Encoding Rules
class AuditData:
"""Store data location, type and validity period of X.509 objects."""
#pylint: disable=too-few-public-methods
def __init__(self, data_type: DataType, x509_obj):
self.data_type = data_type
self.location = ""
self.fill_validity_duration(x509_obj)
def fill_validity_duration(self, x509_obj):
"""Read validity period from an X.509 object."""
# Certificate expires after "not_valid_after"
# Certificate is invalid before "not_valid_before"
if self.data_type == DataType.CRT:
self.not_valid_after = x509_obj.not_valid_after
self.not_valid_before = x509_obj.not_valid_before
# CertificateRevocationList expires after "next_update"
# CertificateRevocationList is invalid before "last_update"
elif self.data_type == DataType.CRL:
self.not_valid_after = x509_obj.next_update
self.not_valid_before = x509_obj.last_update
# CertificateSigningRequest is always valid.
elif self.data_type == DataType.CSR:
self.not_valid_after = datetime.datetime.max
self.not_valid_before = datetime.datetime.min
else:
raise ValueError("Unsupported file_type: {}".format(self.data_type))
class X509Parser:
"""A parser class to parse crt/crl/csr file or data in PEM/DER format."""
PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n(?P<data>.*?)-{5}END (?P=type)-{5}\n'
PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n'
PEM_TAGS = {
DataType.CRT: 'CERTIFICATE',
DataType.CRL: 'X509 CRL',
DataType.CSR: 'CERTIFICATE REQUEST'
}
def __init__(self,
backends:
typing.Dict[DataType,
typing.Dict[DataFormat,
typing.Callable[[bytes], object]]]) \
-> None:
self.backends = backends
self.__generate_parsers()
def __generate_parser(self, data_type: DataType):
"""Parser generator for a specific DataType"""
tag = self.PEM_TAGS[data_type]
pem_loader = self.backends[data_type][DataFormat.PEM]
der_loader = self.backends[data_type][DataFormat.DER]
def wrapper(data: bytes):
pem_type = X509Parser.pem_data_type(data)
# It is in PEM format with target tag
if pem_type == tag:
return pem_loader(data)
# It is in PEM format without target tag
if pem_type:
return None
# It might be in DER format
try:
result = der_loader(data)
except ValueError:
result = None
return result
wrapper.__name__ = "{}.parser[{}]".format(type(self).__name__, tag)
return wrapper
def __generate_parsers(self):
"""Generate parsers for all support DataType"""
self.parsers = {}
for data_type, _ in self.PEM_TAGS.items():
self.parsers[data_type] = self.__generate_parser(data_type)
def __getitem__(self, item):
return self.parsers[item]
@staticmethod
def pem_data_type(data: bytes) -> typing.Optional[str]:
"""Get the tag from the data in PEM format
:param data: data to be checked in binary mode.
:return: PEM tag or "" when no tag detected.
"""
m = re.search(X509Parser.PEM_TAG_REGEX, data)
if m is not None:
return m.group('type').decode('UTF-8')
else:
return None
@staticmethod
def check_hex_string(hex_str: str) -> bool:
"""Check if the hex string is possibly DER data."""
hex_len = len(hex_str)
# At least 6 hex char for 3 bytes: Type + Length + Content
if hex_len < 6:
return False
# Check if Type (1 byte) is SEQUENCE.
if hex_str[0:2] != '30':
return False
# Check LENGTH (1 byte) value
content_len = int(hex_str[2:4], base=16)
consumed = 4
if content_len in (128, 255):
# Indefinite or Reserved
return False
elif content_len > 127:
# Definite, Long
length_len = (content_len - 128) * 2
content_len = int(hex_str[consumed:consumed+length_len], base=16)
consumed += length_len
# Check LENGTH
if hex_len != content_len * 2 + consumed:
return False
return True
class Auditor:
"""
A base class that uses X509Parser to parse files to a list of AuditData.
A subclass must implement the following methods:
- collect_default_files: Return a list of file names that are defaultly
used for parsing (auditing). The list will be stored in
Auditor.default_files.
- parse_file: Method that parses a single file to a list of AuditData.
A subclass may override the following methods:
- parse_bytes: Defaultly, it parses `bytes` that contains only one valid
X.509 data(DER/PEM format) to an X.509 object.
- walk_all: Defaultly, it iterates over all the files in the provided
file name list, calls `parse_file` for each file and stores the results
by extending Auditor.audit_data.
"""
def __init__(self, logger):
self.logger = logger
self.default_files = self.collect_default_files()
# A list to store the parsed audit_data.
self.audit_data = [] # type: typing.List[AuditData]
self.parser = X509Parser({
DataType.CRT: {
DataFormat.PEM: x509.load_pem_x509_certificate,
DataFormat.DER: x509.load_der_x509_certificate
},
DataType.CRL: {
DataFormat.PEM: x509.load_pem_x509_crl,
DataFormat.DER: x509.load_der_x509_crl
},
DataType.CSR: {
DataFormat.PEM: x509.load_pem_x509_csr,
DataFormat.DER: x509.load_der_x509_csr
},
})
def collect_default_files(self) -> typing.List[str]:
"""Collect the default files for parsing."""
raise NotImplementedError
def parse_file(self, filename: str) -> typing.List[AuditData]:
"""
Parse a list of AuditData from file.
:param filename: name of the file to parse.
:return list of AuditData parsed from the file.
"""
raise NotImplementedError
def parse_bytes(self, data: bytes):
"""Parse AuditData from bytes."""
for data_type in list(DataType):
try:
result = self.parser[data_type](data)
except ValueError as val_error:
result = None
self.logger.warning(val_error)
if result is not None:
audit_data = AuditData(data_type, result)
return audit_data
return None
def walk_all(self, file_list: typing.Optional[typing.List[str]] = None):
"""
Iterate over all the files in the list and get audit data.
"""
if file_list is None:
file_list = self.default_files
for filename in file_list:
data_list = self.parse_file(filename)
self.audit_data.extend(data_list)
@staticmethod
def find_test_dir():
"""Get the relative path for the MbedTLS test directory."""
return os.path.relpath(build_tree.guess_mbedtls_root() + '/tests')
class TestDataAuditor(Auditor):
"""Class for auditing files in `tests/data_files/`"""
def collect_default_files(self):
"""Collect all files in `tests/data_files/`"""
test_dir = self.find_test_dir()
test_data_glob = os.path.join(test_dir, 'data_files/**')
data_files = [f for f in glob.glob(test_data_glob, recursive=True)
if os.path.isfile(f)]
return data_files
def parse_file(self, filename: str) -> typing.List[AuditData]:
"""
Parse a list of AuditData from data file.
:param filename: name of the file to parse.
:return list of AuditData parsed from the file.
"""
with open(filename, 'rb') as f:
data = f.read()
result = self.parse_bytes(data)
if result is not None:
result.location = filename
return [result]
else:
return []
def parse_suite_data(data_f):
"""
Parses .data file for test arguments that possiblly have a
valid X.509 data. If you need a more precise parser, please
use generate_test_code.parse_test_data instead.
:param data_f: file object of the data file.
:return: Generator that yields test function argument list.
"""
for line in data_f:
line = line.strip()
# Skip comments
if line.startswith('#'):
continue
# Check parameters line
match = re.search(r'\A\w+(.*:)?\"', line)
if match:
# Read test vectors
parts = re.split(r'(?<!\\):', line)
parts = [x for x in parts if x]
args = parts[1:]
yield args
class SuiteDataAuditor(Auditor):
"""Class for auditing files in `tests/suites/*.data`"""
def collect_default_files(self):
"""Collect all files in `tests/suites/*.data`"""
test_dir = self.find_test_dir()
suites_data_folder = os.path.join(test_dir, 'suites')
data_files = glob.glob(os.path.join(suites_data_folder, '*.data'))
return data_files
def parse_file(self, filename: str):
"""
Parse a list of AuditData from test suite data file.
:param filename: name of the file to parse.
:return list of AuditData parsed from the file.
"""
audit_data_list = []
data_f = FileWrapper(filename)
for test_args in parse_suite_data(data_f):
for idx, test_arg in enumerate(test_args):
match = re.match(r'"(?P<data>[0-9a-fA-F]+)"', test_arg)
if not match:
continue
if not X509Parser.check_hex_string(match.group('data')):
continue
audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
if audit_data is None:
continue
audit_data.location = "{}:{}:#{}".format(filename,
data_f.line_no,
idx + 1)
audit_data_list.append(audit_data)
return audit_data_list
def list_all(audit_data: AuditData):
print("{}\t{}\t{}\t{}".format(
audit_data.not_valid_before.isoformat(timespec='seconds'),
audit_data.not_valid_after.isoformat(timespec='seconds'),
audit_data.data_type.name,
audit_data.location))
def configure_logger(logger: logging.Logger) -> None:
"""
Configure the logging.Logger instance so that:
- Format is set to "[%(levelname)s]: %(message)s".
- loglevel >= WARNING are printed to stderr.
- loglevel < WARNING are printed to stdout.
"""
class MaxLevelFilter(logging.Filter):
# pylint: disable=too-few-public-methods
def __init__(self, max_level, name=''):
super().__init__(name)
self.max_level = max_level
def filter(self, record: logging.LogRecord) -> bool:
return record.levelno <= self.max_level
log_formatter = logging.Formatter("[%(levelname)s]: %(message)s")
# set loglevel >= WARNING to be printed to stderr
stderr_hdlr = logging.StreamHandler(sys.stderr)
stderr_hdlr.setLevel(logging.WARNING)
stderr_hdlr.setFormatter(log_formatter)
# set loglevel <= INFO to be printed to stdout
stdout_hdlr = logging.StreamHandler(sys.stdout)
stdout_hdlr.addFilter(MaxLevelFilter(logging.INFO))
stdout_hdlr.setFormatter(log_formatter)
logger.addHandler(stderr_hdlr)
logger.addHandler(stdout_hdlr)
def main():
"""
Perform argument parsing.
"""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('-a', '--all',
action='store_true',
help='list the information of all the files')
parser.add_argument('-v', '--verbose',
action='store_true', dest='verbose',
help='show logs')
parser.add_argument('--from', dest='start_date',
help=('Start of desired validity period (UTC, YYYY-MM-DD). '
'Default: today'),
metavar='DATE')
parser.add_argument('--to', dest='end_date',
help=('End of desired validity period (UTC, YYYY-MM-DD). '
'Default: --from'),
metavar='DATE')
parser.add_argument('--data-files', action='append', nargs='*',
help='data files to audit',
metavar='FILE')
parser.add_argument('--suite-data-files', action='append', nargs='*',
help='suite data files to audit',
metavar='FILE')
args = parser.parse_args()
# start main routine
# setup logger
logger = logging.getLogger()
configure_logger(logger)
logger.setLevel(logging.DEBUG if args.verbose else logging.ERROR)
td_auditor = TestDataAuditor(logger)
sd_auditor = SuiteDataAuditor(logger)
data_files = []
suite_data_files = []
if args.data_files is None and args.suite_data_files is None:
data_files = td_auditor.default_files
suite_data_files = sd_auditor.default_files
else:
if args.data_files is not None:
data_files = [x for l in args.data_files for x in l]
if args.suite_data_files is not None:
suite_data_files = [x for l in args.suite_data_files for x in l]
# validity period start date
if args.start_date:
start_date = datetime.datetime.fromisoformat(args.start_date)
else:
start_date = datetime.datetime.today()
# validity period end date
if args.end_date:
end_date = datetime.datetime.fromisoformat(args.end_date)
else:
end_date = start_date
# go through all the files
td_auditor.walk_all(data_files)
sd_auditor.walk_all(suite_data_files)
audit_results = td_auditor.audit_data + sd_auditor.audit_data
# we filter out the files whose validity duration covers the provided
# duration.
filter_func = lambda d: (start_date < d.not_valid_before) or \
(d.not_valid_after < end_date)
if args.all:
filter_func = None
# filter and output the results
for d in filter(filter_func, audit_results):
list_all(d)
logger.debug("Done!")
check_cryptography_version()
if __name__ == "__main__":
main()

View file

@ -163,7 +163,6 @@ __MBEDTLS_TEST_TEMPLATE__PLATFORM_CODE
"""
import io
import os
import re
import sys
@ -227,43 +226,57 @@ class GeneratorInputError(Exception):
pass
class FileWrapper(io.FileIO):
class FileWrapper:
"""
This class extends built-in io.FileIO class with attribute line_no,
This class extends the file object with attribute line_no,
that indicates line number for the line that is read.
"""
def __init__(self, file_name):
def __init__(self, file_name) -> None:
"""
Instantiate the base class and initialize the line number to 0.
Instantiate the file object and initialize the line number to 0.
:param file_name: File path to open.
"""
super().__init__(file_name, 'r')
# private mix-in file object
self._f = open(file_name, 'rb')
self._line_no = 0
def __iter__(self):
return self
def __next__(self):
"""
This method overrides base class's __next__ method and extends it
method to count the line numbers as each line is read.
This method makes FileWrapper iterable.
It counts the line numbers as each line is read.
:return: Line read from file.
"""
line = super().__next__()
if line is not None:
self._line_no += 1
# Convert byte array to string with correct encoding and
# strip any whitespaces added in the decoding process.
return line.decode(sys.getdefaultencoding()).rstrip() + '\n'
return None
line = self._f.__next__()
self._line_no += 1
# Convert byte array to string with correct encoding and
# strip any whitespaces added in the decoding process.
return line.decode(sys.getdefaultencoding()).rstrip()+ '\n'
def get_line_no(self):
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._f.__exit__(exc_type, exc_val, exc_tb)
@property
def line_no(self):
"""
Gives current line number.
Property that indicates line number for the line that is read.
"""
return self._line_no
line_no = property(get_line_no)
@property
def name(self):
"""
Property that indicates name of the file that is read.
"""
return self._f.name
def split_dep(dep):