Bignum test: remove type restrictrion

The special case list type depends on the arity and the subclass. Remove
type restriction to make defining special case lists more flexible and natural.

Signed-off-by: Janos Follath <janos.follath@arm.com>
This commit is contained in:
Janos Follath 2022-11-19 12:48:17 +00:00
parent c4fca5de3e
commit 98edf21bb4
2 changed files with 21 additions and 5 deletions

View file

@ -15,7 +15,8 @@
# limitations under the License.
from abc import abstractmethod
from typing import Iterator, List, Tuple, TypeVar
from typing import Iterator, List, Tuple, TypeVar, Any
from itertools import chain
from . import test_case
from . import test_data_generation
@ -90,7 +91,7 @@ class OperationCommon(test_data_generation.BaseTest):
"""
symbol = ""
input_values = [] # type: List[str]
input_cases = [] # type: List[Tuple[str, str]]
input_cases = [] # type: List[Any]
unique_combinations_only = True
input_styles = ["variable", "arch_split"] # type: List[str]
input_style = "variable" # type: str
@ -200,7 +201,6 @@ class OperationCommon(test_data_generation.BaseTest):
for a in cls.input_values
for b in cls.input_values
)
yield from cls.input_cases
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
@ -212,14 +212,20 @@ class OperationCommon(test_data_generation.BaseTest):
test_objects = (cls(a, b, bits_in_limb=bil)
for a, b in cls.get_value_pairs()
for bil in cls.limb_sizes)
special_cases = (cls(*args, bits_in_limb=bil) # type: ignore
for args in cls.input_cases
for bil in cls.limb_sizes)
else:
test_objects = (cls(a, b)
for a, b in cls.get_value_pairs())
special_cases = (cls(*args) for args in cls.input_cases)
yield from (valid_test_object.create_test_case()
for valid_test_object in filter(
lambda test_object: test_object.is_valid,
test_objects
))
chain(test_objects, special_cases)
)
)
class ModOperationCommon(OperationCommon):

View file

@ -243,6 +243,16 @@ class BignumCoreMLA(BignumCoreOperation):
"\"{:x}\"".format(carry_8)
]
@classmethod
def get_value_pairs(cls) -> Iterator[Tuple[str, str]]:
"""Generator to yield pairs of inputs.
Combinations are first generated from all input values, and then
specific cases provided.
"""
yield from super().get_value_pairs()
yield from cls.input_cases
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
"""Override for additional scalar input."""