From ac446c8a04b9dacdd546b791e87e4de5eed2427a Mon Sep 17 00:00:00 2001 From: Werner Lewis Date: Wed, 14 Sep 2022 15:12:46 +0100 Subject: [PATCH] Add combination_pairs helper function Wrapper function for itertools.combinations_with_replacement, with explicit cast due to imprecise typing with older versions of mypy. Signed-off-by: Werner Lewis --- tests/scripts/generate_bignum_tests.py | 29 +++++++++++++++----------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/scripts/generate_bignum_tests.py b/tests/scripts/generate_bignum_tests.py index 3453b6bc3..d156f56f8 100755 --- a/tests/scripts/generate_bignum_tests.py +++ b/tests/scripts/generate_bignum_tests.py @@ -73,6 +73,17 @@ def hex_to_int(val: str) -> int: def quote_str(val) -> str: return "\"{}\"".format(val) +def combination_pairs(values: List[T]) -> List[Tuple[T, T]]: + """Return all pair combinations from input values. + + The return value is cast, as older versions of mypy are unable to derive + the specific type returned by itertools.combinations_with_replacement. + """ + return typing.cast( + List[Tuple[T, T]], + list(itertools.combinations_with_replacement(values, 2)) + ) + class BignumTarget(test_generation.BaseTarget, metaclass=ABCMeta): #pylint: disable=abstract-method @@ -165,10 +176,7 @@ class BignumOperation(BignumTarget, metaclass=ABCMeta): Combinations are first generated from all input values, and then specific cases provided. """ - yield from typing.cast( - Iterator[Tuple[str, str]], - itertools.combinations_with_replacement(cls.input_values, 2) - ) + yield from combination_pairs(cls.input_values) yield from cls.input_cases @classmethod @@ -215,14 +223,11 @@ class BignumAdd(BignumOperation): symbol = "+" test_function = "mbedtls_mpi_add_mpi" test_name = "MPI add" - input_cases = typing.cast( - List[Tuple[str, str]], - list(itertools.combinations_with_replacement( - [ - "1c67967269c6", "9cde3", - "-1c67967269c6", "-9cde3", - ], 2 - )) + input_cases = combination_pairs( + [ + "1c67967269c6", "9cde3", + "-1c67967269c6", "-9cde3", + ] ) def result(self) -> str: