From 34719d41f7046ed5bbaefbf209ee97f3603de39c Mon Sep 17 00:00:00 2001 From: Alex Vlasov Date: Wed, 1 May 2024 05:07:40 +0400 Subject: [PATCH] refactoring --- .../instance_generator.py | 152 ++++++++---------- 1 file changed, 69 insertions(+), 83 deletions(-) diff --git a/tests/generators/fork_choice_generated/instance_generator.py b/tests/generators/fork_choice_generated/instance_generator.py index 44a3c9efed..8d8660d83f 100644 --- a/tests/generators/fork_choice_generated/instance_generator.py +++ b/tests/generators/fork_choice_generated/instance_generator.py @@ -5,7 +5,7 @@ from eth2spec.gen_helpers.gen_base.gen_typing import TestCase, TestProvider from itertools import product from toolz.dicttoolz import merge -from typing import Iterable +from typing import Iterable, Callable from importlib import import_module from eth2spec.utils import bls from eth2spec.test.helpers.typing import SpecForkName, PresetBaseName @@ -23,6 +23,48 @@ presets = [MINIMAL] +def _create_providers(test_name: str, /, + forks: Iterable[SpecForkName], + presets: Iterable[PresetBaseName], + debug: bool, + initial_seed: int, + solutions, + number_of_variations: int, + number_of_mutations: int, + test_fn: Callable, + ) -> Iterable[TestProvider]: + def prepare_fn() -> None: + bls.use_milagro() + return + + def make_cases_fn() -> Iterable[TestCase]: + seeds = [initial_seed] + if number_of_variations > 1: + rnd = random.Random(initial_seed) + seeds = [rnd.randint(1, 10000) for _ in range(number_of_variations)] + seeds[0] = initial_seed + + for i, solution in enumerate(solutions): + for seed in seeds: + for fork_name in forks: + for preset_name in presets: + spec = spec_targets[preset_name][fork_name] + mutation_generator = MutatorsGenerator( + spec, seed, number_of_mutations, + lambda: test_fn(fork_name, preset_name, seed, solution), + debug=debug) + for j in range(1 + number_of_mutations): + yield TestCase(fork_name=fork_name, + preset_name=preset_name, + runner_name=GENERATOR_NAME, + handler_name=test_name, + suite_name='fork_choice', + case_name=test_name + '_' + str(i) + '_' + str(seed) + '_' + str(j), + case_fn=mutation_generator.next_test_case) + + yield TestProvider(prepare=prepare_fn, make_cases=make_cases_fn) + + def _import_block_tree_test_fn(): src = import_module('eth2spec.test.phase0.fork_choice.test_sm_links_tree_model') print("generating test vectors from tests source: %s" % src.__name__) @@ -80,50 +122,22 @@ def _create_block_tree_providers(test_name: str, /, number_of_mutations: int, with_attester_slashings: bool, with_invalid_messages: bool) -> Iterable[TestProvider]: - def prepare_fn() -> None: - bls.use_milagro() - return + _test_fn = _import_block_tree_test_fn() - def make_cases_fn() -> Iterable[TestCase]: - _test_fn = _import_block_tree_test_fn() - - def test_fn(phase: str, preset: str, seed: int, solution): - return _test_fn(generator_mode=True, - phase=phase, - preset=preset, - bls_active=BLS_ACTIVE, - debug=debug, - seed=seed, - sm_links=solution['sm_links'], - block_parents=solution['block_parents'], - with_attester_slashings=with_attester_slashings, - with_invalid_messages=with_invalid_messages) + def test_fn(phase: str, preset: str, seed: int, solution): + return _test_fn(generator_mode=True, + phase=phase, + preset=preset, + bls_active=BLS_ACTIVE, + debug=debug, + seed=seed, + sm_links=solution['sm_links'], + block_parents=solution['block_parents'], + with_attester_slashings=with_attester_slashings, + with_invalid_messages=with_invalid_messages) - seeds = [initial_seed] - if number_of_variations > 1: - rnd = random.Random(initial_seed) - seeds = [rnd.randint(1, 10000) for _ in range(number_of_variations)] - seeds[0] = initial_seed - - for i, solution in enumerate(solutions): - for seed in seeds: - for fork_name in forks: - for preset_name in presets: - spec = spec_targets[preset_name][fork_name] - mutation_generator = MutatorsGenerator( - spec, seed, number_of_mutations, - lambda: test_fn(fork_name, preset_name, seed, solution), - debug=debug) - for j in range(1 + number_of_mutations): - yield TestCase(fork_name=fork_name, - preset_name=preset_name, - runner_name=GENERATOR_NAME, - handler_name=test_name, - suite_name='fork_choice', - case_name=test_name + '_' + str(i) + '_' + str(seed) + '_' + str(j), - case_fn=mutation_generator.next_test_case) - - yield TestProvider(prepare=prepare_fn, make_cases=make_cases_fn) + yield from _create_providers( + test_name, forks, presets, debug, initial_seed, solutions, number_of_variations, number_of_mutations, test_fn) def _import_block_cover_test_fn(): @@ -215,46 +229,18 @@ def _create_block_cover_providers(test_name: str, /, solutions, number_of_variations: int, number_of_mutations: int) -> Iterable[TestProvider]: - def prepare_fn() -> None: - bls.use_milagro() - return - - def make_cases_fn() -> Iterable[TestCase]: - _test_fn = _import_block_cover_test_fn() - def test_fn(phase: str, preset: str, seed: int, solution): - return _test_fn(generator_mode=True, - phase=phase, - preset=preset, - bls_active=BLS_ACTIVE, - debug=debug, - seed=seed, - model_params=solution) - - seeds = [initial_seed] - if number_of_variations > 1: - rnd = random.Random(initial_seed) - seeds = [rnd.randint(1, 10000) for _ in range(number_of_variations)] - seeds[0] = initial_seed - - for i, solution in enumerate(solutions): - for seed in seeds: - for fork_name in forks: - for preset_name in presets: - spec = spec_targets[preset_name][fork_name] - mutation_generator = MutatorsGenerator( - spec, seed, number_of_mutations, - lambda: test_fn(fork_name, preset_name, seed, solution), - debug=debug) - for j in range(1 + number_of_mutations): - yield TestCase(fork_name=fork_name, - preset_name=preset_name, - runner_name=GENERATOR_NAME, - handler_name=test_name, - suite_name='fork_choice', - case_name=test_name + '_' + str(i) + '_' + str(seed) + '_' + str(j), - case_fn=mutation_generator.next_test_case) - - yield TestProvider(prepare=prepare_fn, make_cases=make_cases_fn) + _test_fn = _import_block_cover_test_fn() + def test_fn(phase: str, preset: str, seed: int, solution): + return _test_fn(generator_mode=True, + phase=phase, + preset=preset, + bls_active=BLS_ACTIVE, + debug=debug, + seed=seed, + model_params=solution) + + yield from _create_providers( + test_name, forks, presets, debug, initial_seed, solutions, number_of_variations, number_of_mutations, test_fn) if __name__ == "__main__":