Merge pull request #324 from gilles-peskine-arm/psa-test_psa_constant_names-refactor_and_ka

test_psa_constant_names: support key agreement, better code structure
This commit is contained in:
Gilles Peskine 2019-11-26 16:01:31 +01:00 committed by GitHub
commit 4eca19bbd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 184 additions and 83 deletions

View file

@ -8,6 +8,7 @@ or 1 (with a Python backtrace) if there was an operational error.
""" """
import argparse import argparse
from collections import namedtuple
import itertools import itertools
import os import os
import platform import platform
@ -60,12 +61,15 @@ class read_file_lines:
from exc_value from exc_value
class Inputs: class Inputs:
# pylint: disable=too-many-instance-attributes
"""Accumulate information about macros to test. """Accumulate information about macros to test.
This includes macro names as well as information about their arguments This includes macro names as well as information about their arguments
when applicable. when applicable.
""" """
def __init__(self): def __init__(self):
self.all_declared = set()
# Sets of names per type # Sets of names per type
self.statuses = set(['PSA_SUCCESS']) self.statuses = set(['PSA_SUCCESS'])
self.algorithms = set(['0xffffffff']) self.algorithms = set(['0xffffffff'])
@ -86,11 +90,30 @@ class Inputs:
self.table_by_prefix = { self.table_by_prefix = {
'ERROR': self.statuses, 'ERROR': self.statuses,
'ALG': self.algorithms, 'ALG': self.algorithms,
'CURVE': self.ecc_curves, 'ECC_CURVE': self.ecc_curves,
'GROUP': self.dh_groups, 'DH_GROUP': self.dh_groups,
'KEY_TYPE': self.key_types, 'KEY_TYPE': self.key_types,
'KEY_USAGE': self.key_usage_flags, 'KEY_USAGE': self.key_usage_flags,
} }
# Test functions
self.table_by_test_function = {
# Any function ending in _algorithm also gets added to
# self.algorithms.
'key_type': [self.key_types],
'ecc_key_types': [self.ecc_curves],
'dh_key_types': [self.dh_groups],
'hash_algorithm': [self.hash_algorithms],
'mac_algorithm': [self.mac_algorithms],
'cipher_algorithm': [],
'hmac_algorithm': [self.mac_algorithms],
'aead_algorithm': [self.aead_algorithms],
'key_derivation_algorithm': [self.kdf_algorithms],
'key_agreement_algorithm': [self.ka_algorithms],
'asymmetric_signature_algorithm': [],
'asymmetric_signature_wildcard': [self.algorithms],
'asymmetric_encryption_algorithm': [],
'other_algorithm': [],
}
# macro name -> list of argument names # macro name -> list of argument names
self.argspecs = {} self.argspecs = {}
# argument name -> list of values # argument name -> list of values
@ -99,8 +122,20 @@ class Inputs:
'tag_length': ['1', '63'], 'tag_length': ['1', '63'],
} }
def get_names(self, type_word):
"""Return the set of known names of values of the given type."""
return {
'status': self.statuses,
'algorithm': self.algorithms,
'ecc_curve': self.ecc_curves,
'dh_group': self.dh_groups,
'key_type': self.key_types,
'key_usage': self.key_usage_flags,
}[type_word]
def gather_arguments(self): def gather_arguments(self):
"""Populate the list of values for macro arguments. """Populate the list of values for macro arguments.
Call this after parsing all the inputs. Call this after parsing all the inputs.
""" """
self.arguments_for['hash_alg'] = sorted(self.hash_algorithms) self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
@ -118,6 +153,7 @@ class Inputs:
def distribute_arguments(self, name): def distribute_arguments(self, name):
"""Generate macro calls with each tested argument set. """Generate macro calls with each tested argument set.
If name is a macro without arguments, just yield "name". If name is a macro without arguments, just yield "name".
If name is a macro with arguments, yield a series of If name is a macro with arguments, yield a series of
"name(arg1,...,argN)" where each argument takes each possible "name(arg1,...,argN)" where each argument takes each possible
@ -145,6 +181,9 @@ class Inputs:
except BaseException as e: except BaseException as e:
raise Exception('distribute_arguments({})'.format(name)) from e raise Exception('distribute_arguments({})'.format(name)) from e
def generate_expressions(self, names):
return itertools.chain(*map(self.distribute_arguments, names))
_argument_split_re = re.compile(r' *, *') _argument_split_re = re.compile(r' *, *')
@classmethod @classmethod
def _argument_split(cls, arguments): def _argument_split(cls, arguments):
@ -154,7 +193,7 @@ class Inputs:
# Groups: 1=macro name, 2=type, 3=argument list (optional). # Groups: 1=macro name, 2=type, 3=argument list (optional).
_header_line_re = \ _header_line_re = \
re.compile(r'#define +' + re.compile(r'#define +' +
r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' + r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
r'(?:\(([^\n()]*)\))?') r'(?:\(([^\n()]*)\))?')
# Regex of macro names to exclude. # Regex of macro names to exclude.
_excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z') _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
@ -167,10 +206,6 @@ class Inputs:
# Auxiliary macro whose name doesn't fit the usual patterns for # Auxiliary macro whose name doesn't fit the usual patterns for
# auxiliary macros. # auxiliary macros.
'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE', 'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
# PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
# currently doesn't support them.
'PSA_ALG_ECDH',
'PSA_ALG_FFDH',
# Deprecated aliases. # Deprecated aliases.
'PSA_ERROR_UNKNOWN_ERROR', 'PSA_ERROR_UNKNOWN_ERROR',
'PSA_ERROR_OCCUPIED_SLOT', 'PSA_ERROR_OCCUPIED_SLOT',
@ -184,6 +219,7 @@ class Inputs:
if not m: if not m:
return return
name = m.group(1) name = m.group(1)
self.all_declared.add(name)
if re.search(self._excluded_name_re, name) or \ if re.search(self._excluded_name_re, name) or \
name in self._excluded_names: name in self._excluded_names:
return return
@ -200,26 +236,34 @@ class Inputs:
for line in lines: for line in lines:
self.parse_header_line(line) self.parse_header_line(line)
_macro_identifier_re = r'[A-Z]\w+'
def generate_undeclared_names(self, expr):
for name in re.findall(self._macro_identifier_re, expr):
if name not in self.all_declared:
yield name
def accept_test_case_line(self, function, argument):
#pylint: disable=unused-argument
undeclared = list(self.generate_undeclared_names(argument))
if undeclared:
raise Exception('Undeclared names in test case', undeclared)
return True
def add_test_case_line(self, function, argument): def add_test_case_line(self, function, argument):
"""Parse a test case data line, looking for algorithm metadata tests.""" """Parse a test case data line, looking for algorithm metadata tests."""
sets = []
if function.endswith('_algorithm'): if function.endswith('_algorithm'):
# As above, ECDH and FFDH algorithms are excluded for now. sets.append(self.algorithms)
# Support for them will be added in the future. if function == 'key_agreement_algorithm' and \
if 'ECDH' in argument or 'FFDH' in argument: argument.startswith('PSA_ALG_KEY_AGREEMENT('):
return # We only want *raw* key agreement algorithms as such, so
self.algorithms.add(argument) # exclude ones that are already chained with a KDF.
if function == 'hash_algorithm': # Keep the expression as one to test as an algorithm.
self.hash_algorithms.add(argument) function = 'other_algorithm'
elif function in ['mac_algorithm', 'hmac_algorithm']: sets += self.table_by_test_function[function]
self.mac_algorithms.add(argument) if self.accept_test_case_line(function, argument):
elif function == 'aead_algorithm': for s in sets:
self.aead_algorithms.add(argument) s.add(argument)
elif function == 'key_type':
self.key_types.add(argument)
elif function == 'ecc_key_types':
self.ecc_curves.add(argument)
elif function == 'dh_key_types':
self.dh_groups.add(argument)
# Regex matching a *.data line containing a test function call and # Regex matching a *.data line containing a test function call and
# its arguments. The actual definition is partly positional, but this # its arguments. The actual definition is partly positional, but this
@ -233,9 +277,9 @@ class Inputs:
if m: if m:
self.add_test_case_line(m.group(1), m.group(2)) self.add_test_case_line(m.group(1), m.group(2))
def gather_inputs(headers, test_suites): def gather_inputs(headers, test_suites, inputs_class=Inputs):
"""Read the list of inputs to test psa_constant_names with.""" """Read the list of inputs to test psa_constant_names with."""
inputs = Inputs() inputs = inputs_class()
for header in headers: for header in headers:
inputs.parse_header(header) inputs.parse_header(header)
for test_cases in test_suites: for test_cases in test_suites:
@ -252,8 +296,10 @@ def remove_file_if_exists(filename):
except OSError: except OSError:
pass pass
def run_c(options, type_word, names): def run_c(type_word, expressions, include_path=None, keep_c=False):
"""Generate and run a program to print out numerical values for names.""" """Generate and run a program to print out numerical values for expressions."""
if include_path is None:
include_path = []
if type_word == 'status': if type_word == 'status':
cast_to = 'long' cast_to = 'long'
printf_format = '%ld' printf_format = '%ld'
@ -278,18 +324,18 @@ def run_c(options, type_word, names):
int main(void) int main(void)
{ {
''') ''')
for name in names: for expr in expressions:
c_file.write(' printf("{}\\n", ({}) {});\n' c_file.write(' printf("{}\\n", ({}) {});\n'
.format(printf_format, cast_to, name)) .format(printf_format, cast_to, expr))
c_file.write(''' return 0; c_file.write(''' return 0;
} }
''') ''')
c_file.close() c_file.close()
cc = os.getenv('CC', 'cc') cc = os.getenv('CC', 'cc')
subprocess.check_call([cc] + subprocess.check_call([cc] +
['-I' + dir for dir in options.include] + ['-I' + dir for dir in include_path] +
['-o', exe_name, c_name]) ['-o', exe_name, c_name])
if options.keep_c: if keep_c:
sys.stderr.write('List of {} tests kept at {}\n' sys.stderr.write('List of {} tests kept at {}\n'
.format(type_word, c_name)) .format(type_word, c_name))
else: else:
@ -302,76 +348,101 @@ int main(void)
NORMALIZE_STRIP_RE = re.compile(r'\s+') NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr): def normalize(expr):
"""Normalize the C expression so as not to care about trivial differences. """Normalize the C expression so as not to care about trivial differences.
Currently "trivial differences" means whitespace. Currently "trivial differences" means whitespace.
""" """
expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr)) return re.sub(NORMALIZE_STRIP_RE, '', expr)
return expr.strip().split('\n')
def do_test(options, inputs, type_word, names): def collect_values(inputs, type_word, include_path=None, keep_c=False):
"""Generate expressions using known macro names and calculate their values.
Return a list of pairs of (expr, value) where expr is an expression and
value is a string representation of its integer value.
"""
names = inputs.get_names(type_word)
expressions = sorted(inputs.generate_expressions(names))
values = run_c(type_word, expressions,
include_path=include_path, keep_c=keep_c)
return expressions, values
class Tests:
"""An object representing tests and their results."""
Error = namedtuple('Error',
['type', 'expression', 'value', 'output'])
def __init__(self, options):
self.options = options
self.count = 0
self.errors = []
def run_one(self, inputs, type_word):
"""Test psa_constant_names for the specified type. """Test psa_constant_names for the specified type.
Run program on names.
Use inputs to figure out what arguments to pass to macros that Run the program on the names for this type.
Use the inputs to figure out what arguments to pass to macros that
take arguments. take arguments.
""" """
names = sorted(itertools.chain(*map(inputs.distribute_arguments, names))) expressions, values = collect_values(inputs, type_word,
values = run_c(options, type_word, names) include_path=self.options.include,
output = subprocess.check_output([options.program, type_word] + values) keep_c=self.options.keep_c)
output = subprocess.check_output([self.options.program, type_word] +
values)
outputs = output.decode('ascii').strip().split('\n') outputs = output.decode('ascii').strip().split('\n')
errors = [(type_word, name, value, output) self.count += len(expressions)
for (name, value, output) in zip(names, values, outputs) for expr, value, output in zip(expressions, values, outputs):
if normalize(name) != normalize(output)] if normalize(expr) != normalize(output):
return len(names), errors self.errors.append(self.Error(type=type_word,
expression=expr,
value=value,
output=output))
def report_errors(errors): def run_all(self, inputs):
"""Describe each case where the output is not as expected.""" """Run psa_constant_names on all the gathered inputs."""
for type_word, name, value, output in errors: for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
print('For {} "{}", got "{}" (value: {})' 'key_type', 'key_usage']:
.format(type_word, name, output, value)) self.run_one(inputs, type_word)
def run_tests(options, inputs): def report(self, out):
"""Run psa_constant_names on all the gathered inputs. """Describe each case where the output is not as expected.
Return a tuple (count, errors) where count is the total number of inputs
that were tested and errors is the list of cases where the output was Write the errors to ``out``.
not as expected. Also write a total.
""" """
count = 0 for error in self.errors:
errors = [] out.write('For {} "{}", got "{}" (value: {})\n'
for type_word, names in [('status', inputs.statuses), .format(error.type, error.expression,
('algorithm', inputs.algorithms), error.output, error.value))
('ecc_curve', inputs.ecc_curves), out.write('{} test cases'.format(self.count))
('dh_group', inputs.dh_groups), if self.errors:
('key_type', inputs.key_types), out.write(', {} FAIL\n'.format(len(self.errors)))
('key_usage', inputs.key_usage_flags)]: else:
c, e = do_test(options, inputs, type_word, names) out.write(' PASS\n')
count += c
errors += e HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
return count, errors TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']
def main(): def main():
parser = argparse.ArgumentParser(description=globals()['__doc__']) parser = argparse.ArgumentParser(description=globals()['__doc__'])
parser.add_argument('--include', '-I', parser.add_argument('--include', '-I',
action='append', default=['include'], action='append', default=['include'],
help='Directory for header files') help='Directory for header files')
parser.add_argument('--program',
default='programs/psa/psa_constant_names',
help='Program to test')
parser.add_argument('--keep-c', parser.add_argument('--keep-c',
action='store_true', dest='keep_c', default=False, action='store_true', dest='keep_c', default=False,
help='Keep the intermediate C file') help='Keep the intermediate C file')
parser.add_argument('--no-keep-c', parser.add_argument('--no-keep-c',
action='store_false', dest='keep_c', action='store_false', dest='keep_c',
help='Don\'t keep the intermediate C file (default)') help='Don\'t keep the intermediate C file (default)')
parser.add_argument('--program',
default='programs/psa/psa_constant_names',
help='Program to test')
options = parser.parse_args() options = parser.parse_args()
headers = [os.path.join(options.include[0], 'psa', h) headers = [os.path.join(options.include[0], h) for h in HEADERS]
for h in ['crypto.h', 'crypto_extra.h', 'crypto_values.h']] inputs = gather_inputs(headers, TEST_SUITES)
test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data'] tests = Tests(options)
inputs = gather_inputs(headers, test_suites) tests.run_all(inputs)
count, errors = run_tests(options, inputs) tests.report(sys.stdout)
report_errors(errors) if tests.errors:
if errors == []:
print('{} test cases PASS'.format(count))
else:
print('{} test cases, {} FAIL'.format(count, len(errors)))
exit(1) exit(1)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -262,6 +262,26 @@ Key derivation: HKDF using SHA-256
depends_on:MBEDTLS_SHA256_C depends_on:MBEDTLS_SHA256_C
key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_256 ):ALG_IS_HKDF key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_256 ):ALG_IS_HKDF
Key derivation: HKDF using SHA-384
depends_on:MBEDTLS_SHA512_C
key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_384 ):ALG_IS_HKDF
Key derivation: TLS 1.2 PRF using SHA-256
depends_on:MBEDTLS_SHA256_C
key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PRF
Key derivation: TLS 1.2 PRF using SHA-384
depends_on:MBEDTLS_SHA512_C
key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PRF
Key derivation: TLS 1.2 PSK-to-MS using SHA-256
depends_on:MBEDTLS_SHA256_C
key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PSK_TO_MS
Key derivation: TLS 1.2 PSK-to-MS using SHA-384
depends_on:MBEDTLS_SHA512_C
key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PSK_TO_MS
Key agreement: FFDH, raw output Key agreement: FFDH, raw output
depends_on:MBEDTLS_DHM_C depends_on:MBEDTLS_DHM_C
key_agreement_algorithm:PSA_ALG_FFDH:ALG_IS_FFDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_FFDH:PSA_ALG_CATEGORY_KEY_DERIVATION key_agreement_algorithm:PSA_ALG_FFDH:ALG_IS_FFDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_FFDH:PSA_ALG_CATEGORY_KEY_DERIVATION
@ -270,6 +290,10 @@ Key agreement: FFDH, HKDF using SHA-256
depends_on:MBEDTLS_DHM_C depends_on:MBEDTLS_DHM_C
key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 ) key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 )
Key agreement: FFDH, HKDF using SHA-384
depends_on:MBEDTLS_DHM_C
key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 )
Key agreement: ECDH, raw output Key agreement: ECDH, raw output
depends_on:MBEDTLS_ECDH_C depends_on:MBEDTLS_ECDH_C
key_agreement_algorithm:PSA_ALG_ECDH:ALG_IS_ECDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_ECDH:PSA_ALG_CATEGORY_KEY_DERIVATION key_agreement_algorithm:PSA_ALG_ECDH:ALG_IS_ECDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_ECDH:PSA_ALG_CATEGORY_KEY_DERIVATION
@ -278,6 +302,10 @@ Key agreement: ECDH, HKDF using SHA-256
depends_on:MBEDTLS_ECDH_C depends_on:MBEDTLS_ECDH_C
key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 ) key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 )
Key agreement: ECDH, HKDF using SHA-384
depends_on:MBEDTLS_ECDH_C
key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 )
Key type: raw data Key type: raw data
key_type:PSA_KEY_TYPE_RAW_DATA:KEY_TYPE_IS_UNSTRUCTURED key_type:PSA_KEY_TYPE_RAW_DATA:KEY_TYPE_IS_UNSTRUCTURED

View file

@ -37,6 +37,8 @@
#define ALG_IS_WILDCARD ( 1u << 19 ) #define ALG_IS_WILDCARD ( 1u << 19 )
#define ALG_IS_RAW_KEY_AGREEMENT ( 1u << 20 ) #define ALG_IS_RAW_KEY_AGREEMENT ( 1u << 20 )
#define ALG_IS_AEAD_ON_BLOCK_CIPHER ( 1u << 21 ) #define ALG_IS_AEAD_ON_BLOCK_CIPHER ( 1u << 21 )
#define ALG_IS_TLS12_PRF ( 1u << 22 )
#define ALG_IS_TLS12_PSK_TO_MS ( 1u << 23 )
/* Flags for key type classification macros. There is a flag for every /* Flags for key type classification macros. There is a flag for every
* key type classification macro PSA_KEY_TYPE_IS_xxx except for some that * key type classification macro PSA_KEY_TYPE_IS_xxx except for some that