diff --git a/tests/scripts/generate_bignum_tests.py b/tests/scripts/generate_bignum_tests.py index 3453b6bc3..d156f56f8 100755 --- a/tests/scripts/generate_bignum_tests.py +++ b/tests/scripts/generate_bignum_tests.py @@ -73,6 +73,17 @@ def hex_to_int(val: str) -> int: def quote_str(val) -> str: return "\"{}\"".format(val) +def combination_pairs(values: List[T]) -> List[Tuple[T, T]]: + """Return all pair combinations from input values. + + The return value is cast, as older versions of mypy are unable to derive + the specific type returned by itertools.combinations_with_replacement. + """ + return typing.cast( + List[Tuple[T, T]], + list(itertools.combinations_with_replacement(values, 2)) + ) + class BignumTarget(test_generation.BaseTarget, metaclass=ABCMeta): #pylint: disable=abstract-method @@ -165,10 +176,7 @@ class BignumOperation(BignumTarget, metaclass=ABCMeta): Combinations are first generated from all input values, and then specific cases provided. """ - yield from typing.cast( - Iterator[Tuple[str, str]], - itertools.combinations_with_replacement(cls.input_values, 2) - ) + yield from combination_pairs(cls.input_values) yield from cls.input_cases @classmethod @@ -215,14 +223,11 @@ class BignumAdd(BignumOperation): symbol = "+" test_function = "mbedtls_mpi_add_mpi" test_name = "MPI add" - input_cases = typing.cast( - List[Tuple[str, str]], - list(itertools.combinations_with_replacement( - [ - "1c67967269c6", "9cde3", - "-1c67967269c6", "-9cde3", - ], 2 - )) + input_cases = combination_pairs( + [ + "1c67967269c6", "9cde3", + "-1c67967269c6", "-9cde3", + ] ) def result(self) -> str: