Add combination_pairs helper function

Wrapper function for itertools.combinations_with_replacement, with
explicit cast due to imprecise typing with older versions of mypy.

Signed-off-by: Werner Lewis <werner.lewis@arm.com>
This commit is contained in:
Werner Lewis 2022-09-14 15:12:46 +01:00
parent b6e809133d
commit ac446c8a04

View file

@ -73,6 +73,17 @@ def hex_to_int(val: str) -> int:
def quote_str(val) -> str: def quote_str(val) -> str:
return "\"{}\"".format(val) return "\"{}\"".format(val)
def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
"""Return all pair combinations from input values.
The return value is cast, as older versions of mypy are unable to derive
the specific type returned by itertools.combinations_with_replacement.
"""
return typing.cast(
List[Tuple[T, T]],
list(itertools.combinations_with_replacement(values, 2))
)
class BignumTarget(test_generation.BaseTarget, metaclass=ABCMeta): class BignumTarget(test_generation.BaseTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method #pylint: disable=abstract-method
@ -165,10 +176,7 @@ class BignumOperation(BignumTarget, metaclass=ABCMeta):
Combinations are first generated from all input values, and then Combinations are first generated from all input values, and then
specific cases provided. specific cases provided.
""" """
yield from typing.cast( yield from combination_pairs(cls.input_values)
Iterator[Tuple[str, str]],
itertools.combinations_with_replacement(cls.input_values, 2)
)
yield from cls.input_cases yield from cls.input_cases
@classmethod @classmethod
@ -215,14 +223,11 @@ class BignumAdd(BignumOperation):
symbol = "+" symbol = "+"
test_function = "mbedtls_mpi_add_mpi" test_function = "mbedtls_mpi_add_mpi"
test_name = "MPI add" test_name = "MPI add"
input_cases = typing.cast( input_cases = combination_pairs(
List[Tuple[str, str]],
list(itertools.combinations_with_replacement(
[ [
"1c67967269c6", "9cde3", "1c67967269c6", "9cde3",
"-1c67967269c6", "-9cde3", "-1c67967269c6", "-9cde3",
], 2 ]
))
) )
def result(self) -> str: def result(self) -> str: