diff --git a/scripts/ci.requirements.txt b/scripts/ci.requirements.txt index 1ad983fa9..3ddc41705 100644 --- a/scripts/ci.requirements.txt +++ b/scripts/ci.requirements.txt @@ -10,3 +10,9 @@ pylint == 2.4.4 # Use the earliest version of mypy that works with our code base. # See https://github.com/Mbed-TLS/mbedtls/pull/3953 . mypy >= 0.780 + +# Install cryptography to avoid import-error reported by pylint. +# What we really need is cryptography >= 35.0.0, which is only +# available for Python >= 3.6. +cryptography >= 35.0.0; sys_platform == 'linux' and python_version >= '3.6' +cryptography; sys_platform == 'linux' and python_version < '3.6' diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py new file mode 100755 index 000000000..1ccfc2188 --- /dev/null +++ b/tests/scripts/audit-validity-dates.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +# +# Copyright The Mbed TLS Contributors +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Audit validity date of X509 crt/crl/csr. + +This script is used to audit the validity date of crt/crl/csr used for testing. +It prints the information about X.509 objects excluding the objects that +are valid throughout the desired validity period. The data are collected +from tests/data_files/ and tests/suites/*.data files by default. +""" + +import os +import sys +import re +import typing +import argparse +import datetime +import glob +import logging +from enum import Enum + +# The script requires cryptography >= 35.0.0 which is only available +# for Python >= 3.6. +import cryptography +from cryptography import x509 + +from generate_test_code import FileWrapper + +import scripts_path # pylint: disable=unused-import +from mbedtls_dev import build_tree + +def check_cryptography_version(): + match = re.match(r'^[0-9]+', cryptography.__version__) + if match is None or int(match[0]) < 35: + raise Exception("audit-validity-dates requires cryptography >= 35.0.0" + + "({} is too old)".format(cryptography.__version__)) + +class DataType(Enum): + CRT = 1 # Certificate + CRL = 2 # Certificate Revocation List + CSR = 3 # Certificate Signing Request + + +class DataFormat(Enum): + PEM = 1 # Privacy-Enhanced Mail + DER = 2 # Distinguished Encoding Rules + + +class AuditData: + """Store data location, type and validity period of X.509 objects.""" + #pylint: disable=too-few-public-methods + def __init__(self, data_type: DataType, x509_obj): + self.data_type = data_type + self.location = "" + self.fill_validity_duration(x509_obj) + + def fill_validity_duration(self, x509_obj): + """Read validity period from an X.509 object.""" + # Certificate expires after "not_valid_after" + # Certificate is invalid before "not_valid_before" + if self.data_type == DataType.CRT: + self.not_valid_after = x509_obj.not_valid_after + self.not_valid_before = x509_obj.not_valid_before + # CertificateRevocationList expires after "next_update" + # CertificateRevocationList is invalid before "last_update" + elif self.data_type == DataType.CRL: + self.not_valid_after = x509_obj.next_update + self.not_valid_before = x509_obj.last_update + # CertificateSigningRequest is always valid. + elif self.data_type == DataType.CSR: + self.not_valid_after = datetime.datetime.max + self.not_valid_before = datetime.datetime.min + else: + raise ValueError("Unsupported file_type: {}".format(self.data_type)) + + +class X509Parser: + """A parser class to parse crt/crl/csr file or data in PEM/DER format.""" + PEM_REGEX = br'-{5}BEGIN (?P.*?)-{5}\n(?P.*?)-{5}END (?P=type)-{5}\n' + PEM_TAG_REGEX = br'-{5}BEGIN (?P.*?)-{5}\n' + PEM_TAGS = { + DataType.CRT: 'CERTIFICATE', + DataType.CRL: 'X509 CRL', + DataType.CSR: 'CERTIFICATE REQUEST' + } + + def __init__(self, + backends: + typing.Dict[DataType, + typing.Dict[DataFormat, + typing.Callable[[bytes], object]]]) \ + -> None: + self.backends = backends + self.__generate_parsers() + + def __generate_parser(self, data_type: DataType): + """Parser generator for a specific DataType""" + tag = self.PEM_TAGS[data_type] + pem_loader = self.backends[data_type][DataFormat.PEM] + der_loader = self.backends[data_type][DataFormat.DER] + def wrapper(data: bytes): + pem_type = X509Parser.pem_data_type(data) + # It is in PEM format with target tag + if pem_type == tag: + return pem_loader(data) + # It is in PEM format without target tag + if pem_type: + return None + # It might be in DER format + try: + result = der_loader(data) + except ValueError: + result = None + return result + wrapper.__name__ = "{}.parser[{}]".format(type(self).__name__, tag) + return wrapper + + def __generate_parsers(self): + """Generate parsers for all support DataType""" + self.parsers = {} + for data_type, _ in self.PEM_TAGS.items(): + self.parsers[data_type] = self.__generate_parser(data_type) + + def __getitem__(self, item): + return self.parsers[item] + + @staticmethod + def pem_data_type(data: bytes) -> typing.Optional[str]: + """Get the tag from the data in PEM format + + :param data: data to be checked in binary mode. + :return: PEM tag or "" when no tag detected. + """ + m = re.search(X509Parser.PEM_TAG_REGEX, data) + if m is not None: + return m.group('type').decode('UTF-8') + else: + return None + + @staticmethod + def check_hex_string(hex_str: str) -> bool: + """Check if the hex string is possibly DER data.""" + hex_len = len(hex_str) + # At least 6 hex char for 3 bytes: Type + Length + Content + if hex_len < 6: + return False + # Check if Type (1 byte) is SEQUENCE. + if hex_str[0:2] != '30': + return False + # Check LENGTH (1 byte) value + content_len = int(hex_str[2:4], base=16) + consumed = 4 + if content_len in (128, 255): + # Indefinite or Reserved + return False + elif content_len > 127: + # Definite, Long + length_len = (content_len - 128) * 2 + content_len = int(hex_str[consumed:consumed+length_len], base=16) + consumed += length_len + # Check LENGTH + if hex_len != content_len * 2 + consumed: + return False + return True + + +class Auditor: + """ + A base class that uses X509Parser to parse files to a list of AuditData. + + A subclass must implement the following methods: + - collect_default_files: Return a list of file names that are defaultly + used for parsing (auditing). The list will be stored in + Auditor.default_files. + - parse_file: Method that parses a single file to a list of AuditData. + + A subclass may override the following methods: + - parse_bytes: Defaultly, it parses `bytes` that contains only one valid + X.509 data(DER/PEM format) to an X.509 object. + - walk_all: Defaultly, it iterates over all the files in the provided + file name list, calls `parse_file` for each file and stores the results + by extending Auditor.audit_data. + """ + def __init__(self, logger): + self.logger = logger + self.default_files = self.collect_default_files() + # A list to store the parsed audit_data. + self.audit_data = [] # type: typing.List[AuditData] + self.parser = X509Parser({ + DataType.CRT: { + DataFormat.PEM: x509.load_pem_x509_certificate, + DataFormat.DER: x509.load_der_x509_certificate + }, + DataType.CRL: { + DataFormat.PEM: x509.load_pem_x509_crl, + DataFormat.DER: x509.load_der_x509_crl + }, + DataType.CSR: { + DataFormat.PEM: x509.load_pem_x509_csr, + DataFormat.DER: x509.load_der_x509_csr + }, + }) + + def collect_default_files(self) -> typing.List[str]: + """Collect the default files for parsing.""" + raise NotImplementedError + + def parse_file(self, filename: str) -> typing.List[AuditData]: + """ + Parse a list of AuditData from file. + + :param filename: name of the file to parse. + :return list of AuditData parsed from the file. + """ + raise NotImplementedError + + def parse_bytes(self, data: bytes): + """Parse AuditData from bytes.""" + for data_type in list(DataType): + try: + result = self.parser[data_type](data) + except ValueError as val_error: + result = None + self.logger.warning(val_error) + if result is not None: + audit_data = AuditData(data_type, result) + return audit_data + return None + + def walk_all(self, file_list: typing.Optional[typing.List[str]] = None): + """ + Iterate over all the files in the list and get audit data. + """ + if file_list is None: + file_list = self.default_files + for filename in file_list: + data_list = self.parse_file(filename) + self.audit_data.extend(data_list) + + @staticmethod + def find_test_dir(): + """Get the relative path for the MbedTLS test directory.""" + return os.path.relpath(build_tree.guess_mbedtls_root() + '/tests') + + +class TestDataAuditor(Auditor): + """Class for auditing files in `tests/data_files/`""" + + def collect_default_files(self): + """Collect all files in `tests/data_files/`""" + test_dir = self.find_test_dir() + test_data_glob = os.path.join(test_dir, 'data_files/**') + data_files = [f for f in glob.glob(test_data_glob, recursive=True) + if os.path.isfile(f)] + return data_files + + def parse_file(self, filename: str) -> typing.List[AuditData]: + """ + Parse a list of AuditData from data file. + + :param filename: name of the file to parse. + :return list of AuditData parsed from the file. + """ + with open(filename, 'rb') as f: + data = f.read() + result = self.parse_bytes(data) + if result is not None: + result.location = filename + return [result] + else: + return [] + + +def parse_suite_data(data_f): + """ + Parses .data file for test arguments that possiblly have a + valid X.509 data. If you need a more precise parser, please + use generate_test_code.parse_test_data instead. + + :param data_f: file object of the data file. + :return: Generator that yields test function argument list. + """ + for line in data_f: + line = line.strip() + # Skip comments + if line.startswith('#'): + continue + + # Check parameters line + match = re.search(r'\A\w+(.*:)?\"', line) + if match: + # Read test vectors + parts = re.split(r'(?[0-9a-fA-F]+)"', test_arg) + if not match: + continue + if not X509Parser.check_hex_string(match.group('data')): + continue + audit_data = self.parse_bytes(bytes.fromhex(match.group('data'))) + if audit_data is None: + continue + audit_data.location = "{}:{}:#{}".format(filename, + data_f.line_no, + idx + 1) + audit_data_list.append(audit_data) + + return audit_data_list + + +def list_all(audit_data: AuditData): + print("{}\t{}\t{}\t{}".format( + audit_data.not_valid_before.isoformat(timespec='seconds'), + audit_data.not_valid_after.isoformat(timespec='seconds'), + audit_data.data_type.name, + audit_data.location)) + + +def configure_logger(logger: logging.Logger) -> None: + """ + Configure the logging.Logger instance so that: + - Format is set to "[%(levelname)s]: %(message)s". + - loglevel >= WARNING are printed to stderr. + - loglevel < WARNING are printed to stdout. + """ + class MaxLevelFilter(logging.Filter): + # pylint: disable=too-few-public-methods + def __init__(self, max_level, name=''): + super().__init__(name) + self.max_level = max_level + + def filter(self, record: logging.LogRecord) -> bool: + return record.levelno <= self.max_level + + log_formatter = logging.Formatter("[%(levelname)s]: %(message)s") + + # set loglevel >= WARNING to be printed to stderr + stderr_hdlr = logging.StreamHandler(sys.stderr) + stderr_hdlr.setLevel(logging.WARNING) + stderr_hdlr.setFormatter(log_formatter) + + # set loglevel <= INFO to be printed to stdout + stdout_hdlr = logging.StreamHandler(sys.stdout) + stdout_hdlr.addFilter(MaxLevelFilter(logging.INFO)) + stdout_hdlr.setFormatter(log_formatter) + + logger.addHandler(stderr_hdlr) + logger.addHandler(stdout_hdlr) + + +def main(): + """ + Perform argument parsing. + """ + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument('-a', '--all', + action='store_true', + help='list the information of all the files') + parser.add_argument('-v', '--verbose', + action='store_true', dest='verbose', + help='show logs') + parser.add_argument('--from', dest='start_date', + help=('Start of desired validity period (UTC, YYYY-MM-DD). ' + 'Default: today'), + metavar='DATE') + parser.add_argument('--to', dest='end_date', + help=('End of desired validity period (UTC, YYYY-MM-DD). ' + 'Default: --from'), + metavar='DATE') + parser.add_argument('--data-files', action='append', nargs='*', + help='data files to audit', + metavar='FILE') + parser.add_argument('--suite-data-files', action='append', nargs='*', + help='suite data files to audit', + metavar='FILE') + + args = parser.parse_args() + + # start main routine + # setup logger + logger = logging.getLogger() + configure_logger(logger) + logger.setLevel(logging.DEBUG if args.verbose else logging.ERROR) + + td_auditor = TestDataAuditor(logger) + sd_auditor = SuiteDataAuditor(logger) + + data_files = [] + suite_data_files = [] + if args.data_files is None and args.suite_data_files is None: + data_files = td_auditor.default_files + suite_data_files = sd_auditor.default_files + else: + if args.data_files is not None: + data_files = [x for l in args.data_files for x in l] + if args.suite_data_files is not None: + suite_data_files = [x for l in args.suite_data_files for x in l] + + # validity period start date + if args.start_date: + start_date = datetime.datetime.fromisoformat(args.start_date) + else: + start_date = datetime.datetime.today() + # validity period end date + if args.end_date: + end_date = datetime.datetime.fromisoformat(args.end_date) + else: + end_date = start_date + + # go through all the files + td_auditor.walk_all(data_files) + sd_auditor.walk_all(suite_data_files) + audit_results = td_auditor.audit_data + sd_auditor.audit_data + + # we filter out the files whose validity duration covers the provided + # duration. + filter_func = lambda d: (start_date < d.not_valid_before) or \ + (d.not_valid_after < end_date) + + if args.all: + filter_func = None + + # filter and output the results + for d in filter(filter_func, audit_results): + list_all(d) + + logger.debug("Done!") + +check_cryptography_version() +if __name__ == "__main__": + main() diff --git a/tests/scripts/generate_test_code.py b/tests/scripts/generate_test_code.py index 839fccddd..ff7f9b997 100755 --- a/tests/scripts/generate_test_code.py +++ b/tests/scripts/generate_test_code.py @@ -163,7 +163,6 @@ __MBEDTLS_TEST_TEMPLATE__PLATFORM_CODE """ -import io import os import re import sys @@ -227,43 +226,57 @@ class GeneratorInputError(Exception): pass -class FileWrapper(io.FileIO): +class FileWrapper: """ - This class extends built-in io.FileIO class with attribute line_no, + This class extends the file object with attribute line_no, that indicates line number for the line that is read. """ - def __init__(self, file_name): + def __init__(self, file_name) -> None: """ - Instantiate the base class and initialize the line number to 0. + Instantiate the file object and initialize the line number to 0. :param file_name: File path to open. """ - super().__init__(file_name, 'r') + # private mix-in file object + self._f = open(file_name, 'rb') self._line_no = 0 + def __iter__(self): + return self + def __next__(self): """ - This method overrides base class's __next__ method and extends it - method to count the line numbers as each line is read. + This method makes FileWrapper iterable. + It counts the line numbers as each line is read. :return: Line read from file. """ - line = super().__next__() - if line is not None: - self._line_no += 1 - # Convert byte array to string with correct encoding and - # strip any whitespaces added in the decoding process. - return line.decode(sys.getdefaultencoding()).rstrip() + '\n' - return None + line = self._f.__next__() + self._line_no += 1 + # Convert byte array to string with correct encoding and + # strip any whitespaces added in the decoding process. + return line.decode(sys.getdefaultencoding()).rstrip()+ '\n' - def get_line_no(self): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._f.__exit__(exc_type, exc_val, exc_tb) + + @property + def line_no(self): """ - Gives current line number. + Property that indicates line number for the line that is read. """ return self._line_no - line_no = property(get_line_no) + @property + def name(self): + """ + Property that indicates name of the file that is read. + """ + return self._f.name def split_dep(dep):