diff --git a/ebl/dictionary/web/word_search.py b/ebl/dictionary/web/word_search.py index 4e2743457..9b360eb0b 100644 --- a/ebl/dictionary/web/word_search.py +++ b/ebl/dictionary/web/word_search.py @@ -7,10 +7,10 @@ class WordSearch: def __init__(self, dictionary): self._dispatch = create_dispatcher( - { - frozenset(["query"]): lambda value: dictionary.search(**value), - frozenset(["lemma"]): lambda value: dictionary.search_lemma(**value), - } + [ + lambda query: dictionary.search(query), + lambda lemma: dictionary.search_lemma(lemma), + ] ) @falcon.before(require_scope, "read:words") diff --git a/ebl/dispatcher.py b/ebl/dispatcher.py index 566730005..d07cd957c 100644 --- a/ebl/dispatcher.py +++ b/ebl/dispatcher.py @@ -1,4 +1,5 @@ -from typing import Callable, Mapping, TypeVar, FrozenSet +from typing import Callable, FrozenSet, Mapping, Sequence, TypeVar +from inspect import signature class DispatchError(Exception): @@ -6,7 +7,7 @@ class DispatchError(Exception): T = TypeVar("T") -Command = Callable[[Mapping[str, str]], T] +Command = Callable[..., T] Dispatcher = Callable[[Mapping[str, str]], T] @@ -14,15 +15,21 @@ def get_parameter_names(parameters: Mapping[str, str]) -> FrozenSet[str]: return frozenset(parameters.keys()) -def create_dispatcher(commands: Mapping[FrozenSet[str], Command[T]]) -> Dispatcher[T]: +def create_dispatcher(commands: Sequence[Command[T]]) -> Dispatcher[T]: + command_map = { + frozenset(signature(command).parameters.keys()): command for command in commands + } + if len(command_map) != len(commands): + raise DispatchError("Duplicate arguments in commands.") + def get_command(parameter_names: FrozenSet[str]) -> Command[T]: try: - return commands[parameter_names] + return command_map[parameter_names] except KeyError as error: raise DispatchError(f"Invalid parameters {parameter_names}.") from error def dispatch(parameters: Mapping[str, str]) -> T: parameter_names = get_parameter_names(parameters) - return get_command(parameter_names)(parameters) + return get_command(parameter_names)(**parameters) return dispatch diff --git a/ebl/fragmentarium/web/fragment_search.py b/ebl/fragmentarium/web/fragment_search.py index c8db78551..7298b4331 100644 --- a/ebl/fragmentarium/web/fragment_search.py +++ b/ebl/fragmentarium/web/fragment_search.py @@ -23,38 +23,32 @@ def __init__( transliteration_query_factory: TransliterationQueryFactory, ): self._dispatch = create_dispatcher( - { - frozenset( - ["id", "pages"] - ): lambda value: finder.search_references_in_fragment_infos( - *self._validate_pages(**value) + [ + lambda id, pages: finder.search_references_in_fragment_infos( + *self._validate_pages(id, pages) ), - frozenset(["number"]): lambda value: finder.search(**value), - frozenset(["random"]): lambda _: finder.find_random(), - frozenset(["interesting"]): lambda _: finder.find_interesting(), - frozenset(["latest"]): lambda _: fragmentarium.find_latest(), - frozenset( - ["needsRevision"] - ): lambda _: fragmentarium.find_needs_revision(), - frozenset( - ["transliteration"] - ): lambda value: finder.search_transliteration( - transliteration_query_factory.create(**value) + lambda number: finder.search(number), + lambda random: finder.find_random(), + lambda interesting: finder.find_interesting(), + lambda latest: fragmentarium.find_latest(), + lambda needsRevision: fragmentarium.find_needs_revision(), + lambda transliteration: finder.search_transliteration( + transliteration_query_factory.create(transliteration) ), - } + ] ) @staticmethod def _validate_pages(id: str, pages: Union[str, None]) -> Tuple[str, str]: - if pages: - try: - int(pages) - return id, pages - except ValueError: - raise DataError(f'Pages "{pages}" not numeric.') - else: + if not pages: return id, "" + try: + int(pages) + return id, pages + except ValueError: + raise DataError(f'Pages "{pages}" not numeric.') + @falcon.before(require_scope, "read:fragments") def on_get(self, req: falcon.Request, resp: falcon.Response) -> None: infos = self._dispatch(req.params) diff --git a/ebl/signs/web/sign_search.py b/ebl/signs/web/sign_search.py index 08968633c..f52aeabe9 100644 --- a/ebl/signs/web/sign_search.py +++ b/ebl/signs/web/sign_search.py @@ -12,24 +12,18 @@ class SignsSearch: def __init__(self, signs: SignRepository): self._dispatch = create_dispatcher( - { - frozenset( - ["listsName", "listsNumber"] - ): lambda params: signs.search_by_lists_name( - params["listsName"], params["listsNumber"] + [ + lambda listsName, listsNumber: signs.search_by_lists_name( + listsName, listsNumber ), - frozenset(["value", "subIndex"]): lambda params: signs.search_all( - params["value"], params["subIndex"] + lambda value, subIndex: signs.search_all(value, subIndex), + lambda value, isIncludeHomophones, subIndex: signs.search_include_homophones( + value ), - frozenset( - ["value", "isIncludeHomophones", "subIndex"] - ): lambda params: signs.search_include_homophones(params["value"]), - frozenset( - ["value", "subIndex", "isComposite"] - ): lambda params: signs.search_composite_signs( - params["value"], params["subIndex"] + lambda value, subIndex, isComposite: signs.search_composite_signs( + value, subIndex ), - } + ] ) @staticmethod diff --git a/ebl/tests/test_dispatcher.py b/ebl/tests/test_dispatcher.py index c76bccfdf..311b979d2 100644 --- a/ebl/tests/test_dispatcher.py +++ b/ebl/tests/test_dispatcher.py @@ -2,40 +2,41 @@ from ebl.dispatcher import DispatchError, create_dispatcher -COMMANDS = { - frozenset(["a"]): lambda value: f"a_{''.join(value.values())}", - frozenset(["b"]): lambda value: f"b_{''.join(value.values())}", - frozenset(["a", "b"]): lambda value: f"a_b_{''.join(value.values())}", -} +COMMANDS = [lambda a: f"a: {a}", lambda b: f"b: {b}", lambda a, b: f"a: {a}, b: {b}"] DISPATCH = create_dispatcher(COMMANDS) @pytest.mark.parametrize( - "parameter, results", + "parameters, results", [ - ({"a": "value"}, "a_value"), - ({"b": "value"}, "b_value"), - ({"a": "value1", "b": "value2"}, "a_b_value1value2"), + ({"a": "value"}, "a: value"), + ({"b": "value"}, "b: value"), + ({"a": "value1", "b": "value2"}, "a: value1, b: value2"), ], ) -def test_valid_params(parameter, results): - assert DISPATCH(parameter) == results +def test_valid_parameters(parameters, results): + assert DISPATCH(parameters) == results @pytest.mark.parametrize( - "parameters", [{}, {"invalid": "parameter"}, {"a": "a", "b": "b", "c": "c"}] + "parameters", + [{}, {"invalid": "parameter"}, {"a": "a", "b": "b", "invalid": "parameter"}], ) -def test_invalid_params(parameters): +def test_invalid_parameters(parameters): with pytest.raises(DispatchError): DISPATCH(parameters) +def test_duplicate_parameters(): + with pytest.raises(DispatchError): + create_dispatcher([lambda duplicate: False, lambda duplicate: True]) + + def test_key_error_from_command(): - parameter = "parameter" message = "An error occurred in the command." - def raise_error(_): + def raise_error(): raise KeyError(message) with pytest.raises(KeyError, match=message): - create_dispatcher({frozenset({parameter}): raise_error})({parameter: "value"}) + create_dispatcher([raise_error])({})