diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py index 724f8d94b..89319870d 100755 --- a/tests/scripts/test_psa_constant_names.py +++ b/tests/scripts/test_psa_constant_names.py @@ -8,6 +8,7 @@ or 1 (with a Python backtrace) if there was an operational error. """ import argparse +from collections import namedtuple import itertools import os import platform @@ -60,12 +61,15 @@ class read_file_lines: from exc_value class Inputs: + # pylint: disable=too-many-instance-attributes """Accumulate information about macros to test. + This includes macro names as well as information about their arguments when applicable. """ def __init__(self): + self.all_declared = set() # Sets of names per type self.statuses = set(['PSA_SUCCESS']) self.algorithms = set(['0xffffffff']) @@ -86,11 +90,30 @@ class Inputs: self.table_by_prefix = { 'ERROR': self.statuses, 'ALG': self.algorithms, - 'CURVE': self.ecc_curves, - 'GROUP': self.dh_groups, + 'ECC_CURVE': self.ecc_curves, + 'DH_GROUP': self.dh_groups, 'KEY_TYPE': self.key_types, 'KEY_USAGE': self.key_usage_flags, } + # Test functions + self.table_by_test_function = { + # Any function ending in _algorithm also gets added to + # self.algorithms. + 'key_type': [self.key_types], + 'ecc_key_types': [self.ecc_curves], + 'dh_key_types': [self.dh_groups], + 'hash_algorithm': [self.hash_algorithms], + 'mac_algorithm': [self.mac_algorithms], + 'cipher_algorithm': [], + 'hmac_algorithm': [self.mac_algorithms], + 'aead_algorithm': [self.aead_algorithms], + 'key_derivation_algorithm': [self.kdf_algorithms], + 'key_agreement_algorithm': [self.ka_algorithms], + 'asymmetric_signature_algorithm': [], + 'asymmetric_signature_wildcard': [self.algorithms], + 'asymmetric_encryption_algorithm': [], + 'other_algorithm': [], + } # macro name -> list of argument names self.argspecs = {} # argument name -> list of values @@ -99,8 +122,20 @@ class Inputs: 'tag_length': ['1', '63'], } + def get_names(self, type_word): + """Return the set of known names of values of the given type.""" + return { + 'status': self.statuses, + 'algorithm': self.algorithms, + 'ecc_curve': self.ecc_curves, + 'dh_group': self.dh_groups, + 'key_type': self.key_types, + 'key_usage': self.key_usage_flags, + }[type_word] + def gather_arguments(self): """Populate the list of values for macro arguments. + Call this after parsing all the inputs. """ self.arguments_for['hash_alg'] = sorted(self.hash_algorithms) @@ -118,6 +153,7 @@ class Inputs: def distribute_arguments(self, name): """Generate macro calls with each tested argument set. + If name is a macro without arguments, just yield "name". If name is a macro with arguments, yield a series of "name(arg1,...,argN)" where each argument takes each possible @@ -145,6 +181,9 @@ class Inputs: except BaseException as e: raise Exception('distribute_arguments({})'.format(name)) from e + def generate_expressions(self, names): + return itertools.chain(*map(self.distribute_arguments, names)) + _argument_split_re = re.compile(r' *, *') @classmethod def _argument_split(cls, arguments): @@ -154,7 +193,7 @@ class Inputs: # Groups: 1=macro name, 2=type, 3=argument list (optional). _header_line_re = \ re.compile(r'#define +' + - r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' + + r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' + r'(?:\(([^\n()]*)\))?') # Regex of macro names to exclude. _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z') @@ -167,10 +206,6 @@ class Inputs: # Auxiliary macro whose name doesn't fit the usual patterns for # auxiliary macros. 'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE', - # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script - # currently doesn't support them. - 'PSA_ALG_ECDH', - 'PSA_ALG_FFDH', # Deprecated aliases. 'PSA_ERROR_UNKNOWN_ERROR', 'PSA_ERROR_OCCUPIED_SLOT', @@ -184,6 +219,7 @@ class Inputs: if not m: return name = m.group(1) + self.all_declared.add(name) if re.search(self._excluded_name_re, name) or \ name in self._excluded_names: return @@ -200,26 +236,34 @@ class Inputs: for line in lines: self.parse_header_line(line) + _macro_identifier_re = r'[A-Z]\w+' + def generate_undeclared_names(self, expr): + for name in re.findall(self._macro_identifier_re, expr): + if name not in self.all_declared: + yield name + + def accept_test_case_line(self, function, argument): + #pylint: disable=unused-argument + undeclared = list(self.generate_undeclared_names(argument)) + if undeclared: + raise Exception('Undeclared names in test case', undeclared) + return True + def add_test_case_line(self, function, argument): """Parse a test case data line, looking for algorithm metadata tests.""" + sets = [] if function.endswith('_algorithm'): - # As above, ECDH and FFDH algorithms are excluded for now. - # Support for them will be added in the future. - if 'ECDH' in argument or 'FFDH' in argument: - return - self.algorithms.add(argument) - if function == 'hash_algorithm': - self.hash_algorithms.add(argument) - elif function in ['mac_algorithm', 'hmac_algorithm']: - self.mac_algorithms.add(argument) - elif function == 'aead_algorithm': - self.aead_algorithms.add(argument) - elif function == 'key_type': - self.key_types.add(argument) - elif function == 'ecc_key_types': - self.ecc_curves.add(argument) - elif function == 'dh_key_types': - self.dh_groups.add(argument) + sets.append(self.algorithms) + if function == 'key_agreement_algorithm' and \ + argument.startswith('PSA_ALG_KEY_AGREEMENT('): + # We only want *raw* key agreement algorithms as such, so + # exclude ones that are already chained with a KDF. + # Keep the expression as one to test as an algorithm. + function = 'other_algorithm' + sets += self.table_by_test_function[function] + if self.accept_test_case_line(function, argument): + for s in sets: + s.add(argument) # Regex matching a *.data line containing a test function call and # its arguments. The actual definition is partly positional, but this @@ -233,9 +277,9 @@ class Inputs: if m: self.add_test_case_line(m.group(1), m.group(2)) -def gather_inputs(headers, test_suites): +def gather_inputs(headers, test_suites, inputs_class=Inputs): """Read the list of inputs to test psa_constant_names with.""" - inputs = Inputs() + inputs = inputs_class() for header in headers: inputs.parse_header(header) for test_cases in test_suites: @@ -252,8 +296,10 @@ def remove_file_if_exists(filename): except OSError: pass -def run_c(options, type_word, names): - """Generate and run a program to print out numerical values for names.""" +def run_c(type_word, expressions, include_path=None, keep_c=False): + """Generate and run a program to print out numerical values for expressions.""" + if include_path is None: + include_path = [] if type_word == 'status': cast_to = 'long' printf_format = '%ld' @@ -278,18 +324,18 @@ def run_c(options, type_word, names): int main(void) { ''') - for name in names: + for expr in expressions: c_file.write(' printf("{}\\n", ({}) {});\n' - .format(printf_format, cast_to, name)) + .format(printf_format, cast_to, expr)) c_file.write(''' return 0; } ''') c_file.close() cc = os.getenv('CC', 'cc') subprocess.check_call([cc] + - ['-I' + dir for dir in options.include] + + ['-I' + dir for dir in include_path] + ['-o', exe_name, c_name]) - if options.keep_c: + if keep_c: sys.stderr.write('List of {} tests kept at {}\n' .format(type_word, c_name)) else: @@ -302,76 +348,101 @@ int main(void) NORMALIZE_STRIP_RE = re.compile(r'\s+') def normalize(expr): """Normalize the C expression so as not to care about trivial differences. + Currently "trivial differences" means whitespace. """ - expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr)) - return expr.strip().split('\n') + return re.sub(NORMALIZE_STRIP_RE, '', expr) -def do_test(options, inputs, type_word, names): - """Test psa_constant_names for the specified type. - Run program on names. - Use inputs to figure out what arguments to pass to macros that - take arguments. +def collect_values(inputs, type_word, include_path=None, keep_c=False): + """Generate expressions using known macro names and calculate their values. + + Return a list of pairs of (expr, value) where expr is an expression and + value is a string representation of its integer value. """ - names = sorted(itertools.chain(*map(inputs.distribute_arguments, names))) - values = run_c(options, type_word, names) - output = subprocess.check_output([options.program, type_word] + values) - outputs = output.decode('ascii').strip().split('\n') - errors = [(type_word, name, value, output) - for (name, value, output) in zip(names, values, outputs) - if normalize(name) != normalize(output)] - return len(names), errors + names = inputs.get_names(type_word) + expressions = sorted(inputs.generate_expressions(names)) + values = run_c(type_word, expressions, + include_path=include_path, keep_c=keep_c) + return expressions, values -def report_errors(errors): - """Describe each case where the output is not as expected.""" - for type_word, name, value, output in errors: - print('For {} "{}", got "{}" (value: {})' - .format(type_word, name, output, value)) +class Tests: + """An object representing tests and their results.""" -def run_tests(options, inputs): - """Run psa_constant_names on all the gathered inputs. - Return a tuple (count, errors) where count is the total number of inputs - that were tested and errors is the list of cases where the output was - not as expected. - """ - count = 0 - errors = [] - for type_word, names in [('status', inputs.statuses), - ('algorithm', inputs.algorithms), - ('ecc_curve', inputs.ecc_curves), - ('dh_group', inputs.dh_groups), - ('key_type', inputs.key_types), - ('key_usage', inputs.key_usage_flags)]: - c, e = do_test(options, inputs, type_word, names) - count += c - errors += e - return count, errors + Error = namedtuple('Error', + ['type', 'expression', 'value', 'output']) + + def __init__(self, options): + self.options = options + self.count = 0 + self.errors = [] + + def run_one(self, inputs, type_word): + """Test psa_constant_names for the specified type. + + Run the program on the names for this type. + Use the inputs to figure out what arguments to pass to macros that + take arguments. + """ + expressions, values = collect_values(inputs, type_word, + include_path=self.options.include, + keep_c=self.options.keep_c) + output = subprocess.check_output([self.options.program, type_word] + + values) + outputs = output.decode('ascii').strip().split('\n') + self.count += len(expressions) + for expr, value, output in zip(expressions, values, outputs): + if normalize(expr) != normalize(output): + self.errors.append(self.Error(type=type_word, + expression=expr, + value=value, + output=output)) + + def run_all(self, inputs): + """Run psa_constant_names on all the gathered inputs.""" + for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group', + 'key_type', 'key_usage']: + self.run_one(inputs, type_word) + + def report(self, out): + """Describe each case where the output is not as expected. + + Write the errors to ``out``. + Also write a total. + """ + for error in self.errors: + out.write('For {} "{}", got "{}" (value: {})\n' + .format(error.type, error.expression, + error.output, error.value)) + out.write('{} test cases'.format(self.count)) + if self.errors: + out.write(', {} FAIL\n'.format(len(self.errors))) + else: + out.write(' PASS\n') + +HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h'] +TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data'] def main(): parser = argparse.ArgumentParser(description=globals()['__doc__']) parser.add_argument('--include', '-I', action='append', default=['include'], help='Directory for header files') - parser.add_argument('--program', - default='programs/psa/psa_constant_names', - help='Program to test') parser.add_argument('--keep-c', action='store_true', dest='keep_c', default=False, help='Keep the intermediate C file') parser.add_argument('--no-keep-c', action='store_false', dest='keep_c', help='Don\'t keep the intermediate C file (default)') + parser.add_argument('--program', + default='programs/psa/psa_constant_names', + help='Program to test') options = parser.parse_args() - headers = [os.path.join(options.include[0], 'psa', h) - for h in ['crypto.h', 'crypto_extra.h', 'crypto_values.h']] - test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data'] - inputs = gather_inputs(headers, test_suites) - count, errors = run_tests(options, inputs) - report_errors(errors) - if errors == []: - print('{} test cases PASS'.format(count)) - else: - print('{} test cases, {} FAIL'.format(count, len(errors))) + headers = [os.path.join(options.include[0], h) for h in HEADERS] + inputs = gather_inputs(headers, TEST_SUITES) + tests = Tests(options) + tests.run_all(inputs) + tests.report(sys.stdout) + if tests.errors: exit(1) if __name__ == '__main__': diff --git a/tests/suites/test_suite_psa_crypto_metadata.data b/tests/suites/test_suite_psa_crypto_metadata.data index e989895d2..9cdee0353 100644 --- a/tests/suites/test_suite_psa_crypto_metadata.data +++ b/tests/suites/test_suite_psa_crypto_metadata.data @@ -262,6 +262,26 @@ Key derivation: HKDF using SHA-256 depends_on:MBEDTLS_SHA256_C key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_256 ):ALG_IS_HKDF +Key derivation: HKDF using SHA-384 +depends_on:MBEDTLS_SHA512_C +key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_384 ):ALG_IS_HKDF + +Key derivation: TLS 1.2 PRF using SHA-256 +depends_on:MBEDTLS_SHA256_C +key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PRF + +Key derivation: TLS 1.2 PRF using SHA-384 +depends_on:MBEDTLS_SHA512_C +key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PRF + +Key derivation: TLS 1.2 PSK-to-MS using SHA-256 +depends_on:MBEDTLS_SHA256_C +key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PSK_TO_MS + +Key derivation: TLS 1.2 PSK-to-MS using SHA-384 +depends_on:MBEDTLS_SHA512_C +key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PSK_TO_MS + Key agreement: FFDH, raw output depends_on:MBEDTLS_DHM_C key_agreement_algorithm:PSA_ALG_FFDH:ALG_IS_FFDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_FFDH:PSA_ALG_CATEGORY_KEY_DERIVATION @@ -270,6 +290,10 @@ Key agreement: FFDH, HKDF using SHA-256 depends_on:MBEDTLS_DHM_C key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 ) +Key agreement: FFDH, HKDF using SHA-384 +depends_on:MBEDTLS_DHM_C +key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 ) + Key agreement: ECDH, raw output depends_on:MBEDTLS_ECDH_C key_agreement_algorithm:PSA_ALG_ECDH:ALG_IS_ECDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_ECDH:PSA_ALG_CATEGORY_KEY_DERIVATION @@ -278,6 +302,10 @@ Key agreement: ECDH, HKDF using SHA-256 depends_on:MBEDTLS_ECDH_C key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 ) +Key agreement: ECDH, HKDF using SHA-384 +depends_on:MBEDTLS_ECDH_C +key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 ) + Key type: raw data key_type:PSA_KEY_TYPE_RAW_DATA:KEY_TYPE_IS_UNSTRUCTURED diff --git a/tests/suites/test_suite_psa_crypto_metadata.function b/tests/suites/test_suite_psa_crypto_metadata.function index a9f1b3938..3a9347e2f 100644 --- a/tests/suites/test_suite_psa_crypto_metadata.function +++ b/tests/suites/test_suite_psa_crypto_metadata.function @@ -37,6 +37,8 @@ #define ALG_IS_WILDCARD ( 1u << 19 ) #define ALG_IS_RAW_KEY_AGREEMENT ( 1u << 20 ) #define ALG_IS_AEAD_ON_BLOCK_CIPHER ( 1u << 21 ) +#define ALG_IS_TLS12_PRF ( 1u << 22 ) +#define ALG_IS_TLS12_PSK_TO_MS ( 1u << 23 ) /* Flags for key type classification macros. There is a flag for every * key type classification macro PSA_KEY_TYPE_IS_xxx except for some that