diff --git a/fundi/debug.py b/fundi/debug.py index d7ac327..2ad0523 100644 --- a/fundi/debug.py +++ b/fundi/debug.py @@ -1,12 +1,13 @@ import typing import collections.abc +from fundi.scope import Scope from fundi.inject import injection_impl from fundi.types import CacheKey, CallableInfo def tree( - scope: collections.abc.Mapping[str, typing.Any], + scope: collections.abc.Mapping[str, typing.Any] | Scope, info: CallableInfo[typing.Any], cache: ( collections.abc.MutableMapping[CacheKey, collections.abc.Mapping[str, typing.Any]] | None @@ -20,6 +21,9 @@ def tree( :param cache: tree generation cache :return: Tree of dependencies """ + if not isinstance(scope, Scope): + scope = Scope.from_legacy(scope) + if cache is None: cache = {} @@ -36,7 +40,7 @@ def tree( def order( - scope: collections.abc.Mapping[str, typing.Any], + scope: collections.abc.Mapping[str, typing.Any] | Scope, info: CallableInfo[typing.Any], cache: ( collections.abc.MutableMapping[CacheKey, list[typing.Callable[..., typing.Any]]] | None @@ -50,6 +54,9 @@ def order( :param cache: solvation cache :return: order of dependencies """ + if not isinstance(scope, Scope): + scope = Scope.from_legacy(scope) + if cache is None: cache = {} diff --git a/fundi/exceptions.py b/fundi/exceptions.py index 3aac7fb..141d0b0 100644 --- a/fundi/exceptions.py +++ b/fundi/exceptions.py @@ -24,3 +24,14 @@ def __init__( super().__init__(f"Generator exited too early") self.function: FunctionType = function self.generator: AsyncGenerator[typing.Any] | Generator[typing.Any, None, None] = generator + + +class InvalidInitialValue(ValueError): + """ + Initial value passed to the ``Scope`` constructor is invalid + """ + + def __init__(self, value: typing.Any): + super().__init__( + f"Initial value is invalid: got {value!r}, but ``TypeFactory`` or ``TypeInstance`` expected" + ) diff --git a/fundi/inject.py b/fundi/inject.py index 09ea954..5cf6281 100644 --- a/fundi/inject.py +++ b/fundi/inject.py @@ -2,22 +2,23 @@ import contextlib import collections.abc +from fundi.scope import Scope from fundi.resolve import resolve from fundi.logging import get_logger from fundi.types import CacheKey, CallableInfo -from fundi.util import call_sync, call_async, add_injection_trace +from fundi.util import call_sync, call_async, add_injection_trace, callable_str injection_logger = get_logger("inject.injection") collection_logger = get_logger("inject.collection") def injection_impl( - scope: collections.abc.Mapping[str, typing.Any], + scope: Scope, info: CallableInfo[typing.Any], cache: collections.abc.MutableMapping[CacheKey, typing.Any], override: collections.abc.Mapping[typing.Callable[..., typing.Any], typing.Any] | None, ) -> collections.abc.Generator[ - tuple[collections.abc.Mapping[str, typing.Any], CallableInfo[typing.Any], bool], + tuple[collections.abc.Mapping[str, typing.Any] | Scope, CallableInfo[typing.Any], bool], typing.Any, None, ]: @@ -41,7 +42,7 @@ def injection_impl( if info.scopehook: collection_logger.debug("Calling scope hook for %r", info.call) - scope = dict(scope) + scope = scope.copy() info.scopehook(scope, info) values: dict[str, typing.Any] = {} @@ -57,7 +58,9 @@ def injection_impl( ), "Dependency expected, got None. This is a bug, please report at https://github.com/KuyuCode/fundi" collection_logger.debug("Passing %r upstream to be injected", dependency.call) - value = yield {**scope, "__fundi_parameter__": result.parameter}, dependency, True + + subscope = scope | Scope.from_legacy({"__fundi_parameter__": result.parameter}) + value = yield subscope, dependency, True if dependency.use_cache: collection_logger.debug( @@ -71,15 +74,19 @@ def injection_impl( collection_logger.debug("Passing %r side effects upstream to be injected", info.call) _values = values.copy() _info = info.copy(True) - _scope = {**scope} - for side_effect in info.side_effects: - yield { - **scope, + _scope = scope.copy() + + subscope = scope | Scope( + { "__values__": _values, "__dependant__": _info, "__scope__": _scope, "__fundi_parameter__": None, - }, side_effect, True + } + ) + + for side_effect in info.side_effects: + yield subscope, side_effect, True collection_logger.debug( "Passing %r with collected values %r to be called", info.call, values @@ -93,7 +100,7 @@ def injection_impl( def inject( - scope: collections.abc.Mapping[str, typing.Any], + scope: collections.abc.Mapping[str, typing.Any] | Scope, info: CallableInfo[typing.Any], stack: contextlib.ExitStack | None = None, cache: collections.abc.MutableMapping[CacheKey, typing.Any] | None = None, @@ -112,7 +119,14 @@ def inject( :return: result of callable """ if info.async_: - raise RuntimeError("Cannot process async functions in synchronous injection") + raise RuntimeError( + "Cannot process async functions ({func}) in synchronous injection".format( + func=callable_str(info.call) + ) + ) + + if not isinstance(scope, Scope): + scope = Scope.from_legacy(scope) if stack is None: injection_logger.debug("Exit stack not provided, creating own") @@ -143,7 +157,7 @@ def inject( inner_info.call, ) - return call_sync(stack, inner_info, inner_scope) + return call_sync(stack, inner_info, inner_scope) # type: ignore except Exception as exc: injection_logger.debug("Passing exception %r (%r) to downstream", exc, type(exc)) with contextlib.suppress(StopIteration): @@ -153,7 +167,7 @@ def inject( async def ainject( - scope: collections.abc.Mapping[str, typing.Any], + scope: collections.abc.Mapping[str, typing.Any] | Scope, info: CallableInfo[typing.Any], stack: contextlib.AsyncExitStack | None = None, cache: collections.abc.MutableMapping[CacheKey, typing.Any] | None = None, @@ -171,6 +185,9 @@ async def ainject( :param override: override dependencies :return: result of callable """ + if not isinstance(scope, Scope): + scope = Scope.from_legacy(scope) + if stack is None: injection_logger.debug("Exit stack not provided, creating own") async with contextlib.AsyncExitStack() as stack: @@ -201,9 +218,9 @@ async def ainject( ) if info.async_: - return await call_async(stack, inner_info, inner_scope) + return await call_async(stack, inner_info, inner_scope) # type: ignore - return call_sync(stack, inner_info, inner_scope) + return call_sync(stack, inner_info, inner_scope) # type: ignore except Exception as exc: injection_logger.debug("Passing exception %r (%r) to downstream", exc, type(exc)) with contextlib.suppress(StopIteration): diff --git a/fundi/inject.pyi b/fundi/inject.pyi index 6336ba0..792465e 100644 --- a/fundi/inject.pyi +++ b/fundi/inject.pyi @@ -2,6 +2,7 @@ import typing from typing import overload from collections.abc import Generator, AsyncGenerator, Mapping, MutableMapping, Awaitable +from fundi.scope import Scope from fundi.types import CacheKey, CallableInfo from contextlib import ( @@ -16,18 +17,18 @@ R = typing.TypeVar("R") ExitStack = AsyncExitStack | SyncExitStack def injection_impl( - scope: Mapping[str, typing.Any], + scope: Scope, info: CallableInfo[typing.Any], cache: MutableMapping[CacheKey, typing.Any], override: Mapping[typing.Callable[..., typing.Any], typing.Any] | None, ) -> Generator[ - tuple[Mapping[str, typing.Any], CallableInfo[typing.Any], bool], + tuple[Mapping[str, typing.Any] | Scope, CallableInfo[typing.Any], bool], typing.Any, None, ]: ... @overload def inject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[Generator[R, None, None]], stack: ExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -35,7 +36,7 @@ def inject( ) -> R: ... @overload def inject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[AbstractContextManager[R]], stack: ExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -43,7 +44,7 @@ def inject( ) -> R: ... @overload def inject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[R], stack: ExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -51,7 +52,7 @@ def inject( ) -> R: ... @overload async def ainject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[Generator[R, None, None]], stack: AsyncExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -59,7 +60,7 @@ async def ainject( ) -> R: ... @overload async def ainject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[AsyncGenerator[R, None]], stack: AsyncExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -67,7 +68,7 @@ async def ainject( ) -> R: ... @overload async def ainject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[Awaitable[R]], stack: AsyncExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -75,7 +76,7 @@ async def ainject( ) -> R: ... @overload async def ainject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[AbstractAsyncContextManager[R]], stack: AsyncExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -83,7 +84,7 @@ async def ainject( ) -> R: ... @overload async def ainject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[AbstractContextManager[R]], stack: AsyncExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, @@ -91,7 +92,7 @@ async def ainject( ) -> R: ... @overload async def ainject( - scope: Mapping[str, typing.Any], + scope: Mapping[str, typing.Any] | Scope, info: CallableInfo[R], stack: AsyncExitStack | None = None, cache: MutableMapping[CacheKey, typing.Any] | None = None, diff --git a/fundi/resolve.py b/fundi/resolve.py index a315b10..87849c5 100644 --- a/fundi/resolve.py +++ b/fundi/resolve.py @@ -2,7 +2,8 @@ import collections.abc from fundi.logging import get_logger -from fundi.util import normalize_annotation +from fundi.util import normalize_annotation, callable_str +from fundi.scope import Scope, NO_VALUE, TypeInstance, TypeFactory from fundi.types import CacheKey, CallableInfo, ParameterResult, Parameter logger = get_logger("resolve") @@ -17,7 +18,7 @@ def resolve_by_dependency( assert dependency is not None - logger.debug("Resolving %r using dependency %r", param.name, dependency.call) + logger.debug("Resolving %r using dependency %s", param.name, callable_str(dependency.call)) value = override.get(dependency.call) if value is not None: @@ -40,25 +41,35 @@ def resolve_by_dependency( return ParameterResult(param, None, dependency, resolved=False) -def resolve_by_type( - scope: collections.abc.Mapping[str, typing.Any], param: Parameter -) -> ParameterResult: +def resolve_by_type(scope: Scope, param: Parameter) -> ParameterResult: logger.debug("Resolving %r using annotation %r", param.name, param.annotation) type_options = normalize_annotation(param.annotation) - for value in scope.values(): - if not isinstance(value, type_options): + for type_ in type_options: + value = scope.resolve_by_type(typing.cast(type[typing.Any], type_)) + + if value is NO_VALUE: continue - logger.debug("Found value %r for %r: Annotation", value, param.name) + match value: + case TypeInstance(value): + logger.debug("Found type instance %r for %r", value, param.name) + return ParameterResult(param, value, None, resolved=True) + case TypeFactory(factory): + logger.debug( + "Found type factory %s for %r", + callable_str(factory.call), + param.name, + ) + return ParameterResult(param, None, factory, False) - return ParameterResult(param, value, None, resolved=True) + logger.debug("Not found value for %r using annotation %r", param.name, param.annotation) return ParameterResult(param, None, None, resolved=False) def resolve( - scope: collections.abc.Mapping[str, typing.Any], + scope: Scope, info: CallableInfo[typing.Any], cache: collections.abc.Mapping[CacheKey, typing.Any], override: collections.abc.Mapping[typing.Callable[..., typing.Any], typing.Any] | None = None, @@ -102,12 +113,17 @@ def resolve( if parameter.resolve_by_type: result = resolve_by_type(scope, parameter) + if result.dependency is not None: + yield resolve_by_dependency( + parameter.copy(from_=result.dependency), cache, override + ) + continue + if result.resolved: yield result continue - elif parameter.name in scope: - value = scope[parameter.name] + elif (value := scope.resolve_by_name(parameter.name)) is not NO_VALUE: logger.debug("Found value %r for %r: Name", value, parameter.name) yield ParameterResult(parameter, value, None, resolved=True) continue diff --git a/fundi/scope.py b/fundi/scope.py new file mode 100644 index 0000000..a80def8 --- /dev/null +++ b/fundi/scope.py @@ -0,0 +1,296 @@ +import typing +from itertools import chain +from dataclasses import dataclass +from collections.abc import Mapping + +from typing_extensions import NewType, overload, override + +from fundi import scan +from fundi.exceptions import InvalidInitialValue + +if typing.TYPE_CHECKING: + from fundi import CallableInfo + + +class NoValue: + """ + No value marker. Do not use this as a value! + """ + + +IGNORE_TYPES: tuple[type, ...] = ( + int, + str, + set, + bool, + dict, + list, + type, + float, + bytes, + tuple, + object, + bytearray, +) + +NO_VALUE = NoValue() + +T = typing.TypeVar("T") +S = typing.TypeVar("S", bound="Scope") + + +@dataclass +class TypeFactory(typing.Generic[T]): + """Marker type. Should be used to determine whether this value is a factory or instance of the type""" + + factory: "CallableInfo[T]" + + +@dataclass +class TypeInstance(typing.Generic[T]): + """Marker type. Should be used to determine whether this value is a factory or instance of the type""" + + instance: T + + +class Scope: + """ + Injection scope. + Stores and resolves dynamic values. + + Created to extend resolving mechanism with more features. + + Allows to store values by string keys, types, MROs. + Also, allows to create type factories - functions that create instances of the type. + """ + + def __init__(self, initial: dict[str | type | NewType, typing.Any] | None = None): + """ + Create the Scope. + + If the key of the ``initial`` is the string then the value is stored under that key. + + If the key is the type then the value is checked whether it is ``TypeInstance`` or ``TypeFactory``. + If the value is the ``TypeInstance`` - it is stored as instance of the type from key. + If the value is the ``TypeFactory`` - it is stored as factory of the type from key. + """ + initial = initial or {} + + self.values: dict[str, typing.Any] = {} + self.types: dict[type | NewType, typing.Any] = {} + self.factories: dict[type | NewType, "CallableInfo[typing.Any]"] = {} + + for key, value in initial.items(): + if isinstance(key, str): + self.values[key] = value + continue + + match value: + case TypeInstance(instance): + self.types[key] = instance + case TypeFactory(factory): + self.factories[key] = factory + case value: + raise InvalidInitialValue(value) + + def add_value(self, key: str, value: typing.Any) -> bool: + """ + Adds named value to the scope. + + Returns True if the value replaced existing one. + """ + if key in self.values: + self.values[key] = value + return True + + self.values[key] = value + return False + + def add_type( + self, + type_or_instance: typing.Any, + instance: typing.Any = NO_VALUE, + mro: bool = False, + ) -> None: + """ + Adds value by type to the scope. + + If ``instance`` is not provided - + uses ``type_or_instance`` as the instance and takes its type as the key. + + If ``type_or_instance`` is a tuple of types then each of the values is used as key for the ``instance`` and + ``mro`` parameter is disabled. + + If mro is True - adds this value by type and type's MRO(method resolution order) + ignoring last entry in the list. + + Returns nothing. + """ + if isinstance(instance, NoValue): + type_ = type(type_or_instance) + instance = type_or_instance + + elif isinstance(type_or_instance, tuple): + for type_ in type_or_instance: + self.types[type_] = instance + return None + + elif isinstance(type_or_instance, (type, NewType)): + type_ = type_or_instance + instance = instance + + else: + raise ValueError("Unable to detect type or value of the assignment") + + self.types[type_] = instance + + if mro: + for type_ in type_.mro()[1:-1]: + self.types[type_] = instance + + @overload + def add_factory( + self, type_: NewType | tuple[type[T], ...], factory: typing.Callable[..., T] + ) -> None: ... + @overload + def add_factory( + self, type_: type[T] | tuple[type[T], ...], factory: typing.Callable[..., T] + ) -> None: ... + def add_factory( + self, type_: type[T] | tuple[type[T], ...], factory: typing.Callable[..., T] + ) -> None: + """ + Adds factory of the type to the scope. + The factory can be any function that can be interpeted as a dependant. + + If the ``type_`` is the tuple of types then the factory is set for each type of that tuple. + + Returns nothing. + """ + if isinstance(type_, tuple): + for type_ in type_: + self.factories[type_] = scan(factory) + return + + self.factories[type_] = scan(factory) + + def update( + self, + mapping: ( + Mapping[ + str | type | NewType, + TypeInstance[typing.Any] | TypeFactory[typing.Any] | typing.Any, + ] + | None + ) = None, + **values: typing.Any, + ): + """ + Update this scope with provided values. + + ``mapping`` argument can be used to add multiple factories and type instances at a time. + And ``values`` argument can be used only for string based values. + """ + self.values.update(values) + + if mapping is None: + return None + + for key, value in mapping.items(): + if isinstance(key, str): + self.values[key] = value + continue + + match value: + case TypeInstance(value): + self.types[key] = value + case TypeFactory(factory): + self.factories[key] = factory + + def resolve_by_name(self, key: str, *, default: typing.Any = NO_VALUE) -> typing.Any | NoValue: + """ + Resolves value by the key name. + + Returns either value or the default. + The default is set to ``NoValue`` instance as the value may be None in the scope. + """ + return self.values.get(key, default) + + @overload + def resolve_by_type( + self, type_: NewType, default: T | NoValue = NO_VALUE + ) -> TypeInstance[typing.Any] | TypeFactory[typing.Any] | T | NoValue: ... + @overload + def resolve_by_type( + self, type_: type[T], default: T | NoValue = NO_VALUE + ) -> TypeInstance[T] | TypeFactory[T] | NoValue: ... + def resolve_by_type( + self, type_: typing.Any, default: typing.Any = NO_VALUE + ) -> TypeInstance[typing.Any] | TypeFactory[typing.Any] | NoValue | typing.Any: + """ + Resolves value or factory by the provided type. + + Returns value wrapped in ``TypeInstance``, factory wrapped in ``TypeFactory`` or the default value. + The default is set to ``NoValue`` instance as the value may be None in the scope. + + Resolution order: instance of the type -> type factory -> default value + """ + if type_ in self.types: + return TypeInstance(self.types[type_]) + + if type_ in self.factories: + return TypeFactory(self.factories[type_]) + + return default + + def merge(self, other: "Scope") -> "Scope": + """ + Merges two scopes together and returns the result as the new Scope instance + """ + new_scope = Scope() + new_scope.values = {**self.values, **other.values} + new_scope.types = {**self.types, **other.types} + new_scope.factories = { + type_: factory + for type_, factory in chain(self.factories.items(), other.factories.items()) + if type_ not in new_scope.types + } + + return new_scope + + def copy(self) -> "Scope": + """ + Make a copy of this scope + """ + return Scope(self.simplify()) + + def simplify(self): + """ + Return simple representation of this scope that can be used in the Scope constructor + """ + return ( + self.values + | {t: TypeInstance(ti) for t, ti in self.types.items()} + | {t: TypeFactory(f) for t, f in self.factories.items()} + ) + + @classmethod + def from_legacy(cls: typing.Callable[[], S], scope: Mapping[str, typing.Any]) -> S: + new_scope = cls() + + for key, value in scope.items(): + new_scope.values[key] = value + + value_type = type(value) + if value_type in IGNORE_TYPES or getattr(value_type, "__module__", None) == "builtins": + continue + + new_scope.types[value_type] = value + + return new_scope + + __or__ = merge + + @override + def __str__(self): + return f"Scope{{named={len(self.values)}, by_type={len(self.types)}, factories={len(self.factories)}}}" diff --git a/fundi/types.py b/fundi/types.py index dfc9129..b482358 100644 --- a/fundi/types.py +++ b/fundi/types.py @@ -7,6 +7,9 @@ from fundi.logging import get_logger +if typing.TYPE_CHECKING: + from fundi.scope import Scope + __all__ = [ "R", "Parameter", @@ -32,7 +35,7 @@ class TypeResolver: annotation: type -ScopeHook = typing.Callable[[dict[str, typing.Any], "CallableInfo[typing.Any]"], typing.Any] +ScopeHook = typing.Callable[["Scope", "CallableInfo[typing.Any]"], typing.Any] @dataclass diff --git a/tests/scope/test_add_factory.py b/tests/scope/test_add_factory.py new file mode 100644 index 0000000..8f5e6ce --- /dev/null +++ b/tests/scope/test_add_factory.py @@ -0,0 +1,32 @@ +from fundi import scan +from typing import NewType +from fundi.scope import Scope + + +def test_default(): + class User: + pass + + def factory() -> User: + return User() + + scope = Scope() + scope.add_factory(User, factory) + + assert scope.factories == {User: scan(factory)} + + +def test_alias(): + + class User: + pass + + def factory() -> User: + return User() + + Actor = NewType("Actor", User) + + scope = Scope() + scope.add_factory(Actor, factory) + + assert scope.factories == {Actor: scan(factory)} diff --git a/tests/scope/test_add_type.py b/tests/scope/test_add_type.py new file mode 100644 index 0000000..7613276 --- /dev/null +++ b/tests/scope/test_add_type.py @@ -0,0 +1,45 @@ +from typing_extensions import NewType + +from fundi.scope import Scope + + +def test_default(): + class User: + pass + + user = User() + + scope = Scope() + scope.add_type(user) + + assert scope.types == {User: user} + + +def test_mro(): + class User: + pass + + class Admin(User): + pass + + admin = Admin() + + scope = Scope() + scope.add_type(admin, mro=True) + + assert scope.types == {User: admin, Admin: admin} + + +def test_alias(): + class User: + pass + + Actor = NewType("Actor", User) + print(Actor, type(Actor)) + + user = User() + + scope = Scope() + scope.add_type(Actor, user) + + assert scope.types == {Actor: user} diff --git a/tests/scope/test_add_value.py b/tests/scope/test_add_value.py new file mode 100644 index 0000000..18d5f63 --- /dev/null +++ b/tests/scope/test_add_value.py @@ -0,0 +1,20 @@ +from fundi.scope import Scope + + +def test_default(): + scope = Scope() + scope.add_value("name", "Kuyugama") + + assert scope.values == {"name": "Kuyugama"} + + assert scope.resolve_by_name("name") == "Kuyugama" + + +def test_replace(): + scope = Scope({"name": "Kuyu"}) + + assert scope.add_value("name", "Kuyugama") + + assert scope.values == {"name": "Kuyugama"} + + assert scope.resolve_by_name("name") == "Kuyugama" diff --git a/tests/scope/test_copy.py b/tests/scope/test_copy.py new file mode 100644 index 0000000..8bf7143 --- /dev/null +++ b/tests/scope/test_copy.py @@ -0,0 +1,18 @@ +from fundi import scan +from fundi.scope import TypeInstance, TypeFactory, Scope + + +def test_default(): + initial = { + "key": "value", + int: TypeInstance(2), # noqa: F821 + str: TypeFactory(scan(lambda: "string")), # noqa: F821 + "another key": "another value", + } + scope = Scope(initial) + + copy = scope.copy() + + assert copy.values == scope.values + assert copy.types == scope.types + assert copy.factories == scope.factories diff --git a/tests/scope/test_creation.py b/tests/scope/test_creation.py new file mode 100644 index 0000000..7c56ed9 --- /dev/null +++ b/tests/scope/test_creation.py @@ -0,0 +1,79 @@ +import pytest + +from fundi import scan +from fundi.exceptions import InvalidInitialValue +from fundi.scope import Scope, TypeInstance, TypeFactory + + +def test_fundamental_types_only(): + initial = { + "int": 1, + "float": 1.1, + "str": "", + "bytes": b"", + "bytearray": bytearray(), + "set": set(), + "dict": {}, + "list": [], + "tuple": (), + "bool": True, + "object": object(), + "type": object, + } + scope = Scope(initial) + + assert scope.values == initial + assert scope.types == {} + assert scope.factories == {} + + +def test_custom_type_without_marker(): + class User: + pass + + initial = {User: User()} + + with pytest.raises(InvalidInitialValue): + Scope(initial) + + +def test_custom_type_with_marker(): + class User: + pass + + initial = {User: TypeInstance(User())} + + scope = Scope(initial) + + assert scope.values == {} + assert scope.types == {User: initial[User].instance} + assert scope.factories == {} + + +def test_factory_without_marker(): + class User: + pass + + def factory() -> User: + return User() + + initial = {User: scan(factory)} + + with pytest.raises(InvalidInitialValue): + Scope(initial) + + +def test_factory_with_marker(): + class User: + pass + + def factory() -> User: + return User() + + initial = {User: TypeFactory(scan(factory))} + + scope = Scope(initial) + + assert scope.values == {} + assert scope.types == {} + assert scope.factories == {User: initial[User].factory} diff --git a/tests/scope/test_from_legacy.py b/tests/scope/test_from_legacy.py new file mode 100644 index 0000000..d55b316 --- /dev/null +++ b/tests/scope/test_from_legacy.py @@ -0,0 +1,35 @@ +from fundi.scope import Scope + + +def test_fundamental_types(): + legacy = { + "int": 1, + "float": 1.1, + "str": "", + "bytes": b"", + "bytearray": bytearray(), + "set": set(), + "dict": {}, + "list": [], + "tuple": (), + "bool": True, + "object": object(), + "type": object, + } + scope = Scope.from_legacy(legacy) + + assert scope.factories == {} + assert scope.types == {} + assert scope.values == legacy + + +def test_custom_types(): + class User: + pass + + legacy = {"user": User()} + scope = Scope.from_legacy(legacy) + + assert scope.values == legacy + assert scope.types == {User: legacy["user"]} + assert scope.factories == {} diff --git a/tests/scope/test_merge.py b/tests/scope/test_merge.py new file mode 100644 index 0000000..3edb48c --- /dev/null +++ b/tests/scope/test_merge.py @@ -0,0 +1,57 @@ +import typing +from typing_extensions import NewType + +from fundi import scan +from fundi.scope import Scope, TypeInstance, TypeFactory + + +def test_replace(): + class AClass: + pass + + class BClass: + pass + + def factory0(): + pass + + def factory1(): + pass + + initial: dict[str | type | NewType, typing.Any] = { + "key": "value", + AClass: TypeInstance(1), + BClass: TypeFactory(scan(factory0)), + } + + scope = Scope(initial) + scope1 = Scope( + {"key": "another value", AClass: TypeInstance(2), BClass: TypeFactory(scan(factory1))} + ) + + scope_merged = scope | scope1 + + assert scope_merged.values == {"key": "another value"} + assert scope_merged.types == {AClass: 2} + assert scope_merged.factories == {BClass: scan(factory1)} + + +def test_extend(): + class AClass: + pass + + class BClass: + pass + + def factory0(): + pass + + scope = Scope({"base_key": "initial value"}) + + scope1 = Scope({"key": "value", AClass: TypeInstance(1), BClass: TypeFactory(scan(factory0))}) + + scope_merged = scope | scope1 + + assert scope_merged.values == {"key": "value", "base_key": "initial value"} + assert scope_merged.types == {AClass: 1} + assert scope_merged.factories == {BClass: scan(factory0)} diff --git a/tests/scope/test_resolve_by_name.py b/tests/scope/test_resolve_by_name.py new file mode 100644 index 0000000..95ffa85 --- /dev/null +++ b/tests/scope/test_resolve_by_name.py @@ -0,0 +1,31 @@ +from fundi.scope import Scope, NO_VALUE + + +def test_valid_name(): + constant = 1 + + initial = { + "constant": constant, + "float": 1.1, + "str": "", + "bytes": b"", + "bytearray": bytearray(), + "set": set(), + "dict": {}, + "list": [], + "tuple": (), + "bool": True, + "object": object(), + "type": object, + } + scope = Scope(initial) + + assert scope.resolve_by_name("constant") is constant + + +def test_invalid_name(): + initial = {"name": -1} + + scope = Scope(initial) + + assert scope.resolve_by_name("invalid-name") is NO_VALUE diff --git a/tests/scope/test_resolve_by_type.py b/tests/scope/test_resolve_by_type.py new file mode 100644 index 0000000..157d9a1 --- /dev/null +++ b/tests/scope/test_resolve_by_type.py @@ -0,0 +1,73 @@ +from typing import NewType +from fundi import scan +from fundi.scope import Scope, NO_VALUE, TypeInstance, TypeFactory + + +def test_exact(): + class User: + pass + + initial = {User: TypeInstance(User())} + scope = Scope(initial) + + value = scope.resolve_by_type(User) + assert isinstance(value, TypeInstance) + assert value.instance is initial[User].instance + assert isinstance(value.instance, User) + + +def test_no_value(): + scope = Scope() + + assert scope.resolve_by_type(type) is NO_VALUE + + +def test_mro(): + class User: + pass + + class Admin(User): + pass + + admin = Admin() + + scope = Scope() + scope.add_type(admin, mro=True) + + value = scope.resolve_by_type(User) + assert isinstance(value, TypeInstance) + assert value.instance is admin + assert isinstance(value.instance, User) and isinstance(value.instance, Admin) + + +def test_alias(): + class User: + pass + + Actor = NewType("Actor", User) + + user = User() + + initial = {Actor: TypeInstance(user)} + scope = Scope(initial) + + value = scope.resolve_by_type(Actor) + assert isinstance(value, TypeInstance) + assert value.instance is user + assert isinstance(value.instance, User) + + +def test_factory(): + class User: + pass + + def factory() -> User: + return User() + + initial = {User: TypeFactory(scan(factory))} + scope = Scope(initial) + + value = scope.resolve_by_type(User) + + assert isinstance(value, TypeFactory) + assert value.factory.call is factory diff --git a/tests/scope/test_simplify.py b/tests/scope/test_simplify.py new file mode 100644 index 0000000..13acd0e --- /dev/null +++ b/tests/scope/test_simplify.py @@ -0,0 +1,15 @@ +from fundi import scan +from fundi.scope import Scope, TypeInstance, TypeFactory + + +def test_default(): + initial = { + "key": "value", + int: TypeInstance(2), + str: TypeFactory(scan(lambda: "string")), + "another key": "another value", + } + scope = Scope(initial) + + simplified = scope.simplify() + assert simplified == initial diff --git a/tests/scope/test_update.py b/tests/scope/test_update.py new file mode 100644 index 0000000..9adc965 --- /dev/null +++ b/tests/scope/test_update.py @@ -0,0 +1,30 @@ +from fundi import scan +from fundi.scope import Scope, TypeInstance, TypeFactory + + +def test_default(): + scope = Scope() + scope.update(some="value") + + assert scope.values == {"some": "value"} + assert scope.types == {} + assert scope.factories == {} + + +def test_type_instances(): + scope = Scope() + scope.update({int: TypeInstance(2)}) + + assert scope.values == {} + assert scope.types == {int: 2} + assert scope.factories == {} + + +def test_type_factories(): + scope = Scope() + factory = scan(lambda: 1) + scope.update({int: TypeFactory(factory)}) + + assert scope.values == {} + assert scope.types == {} + assert scope.factories == {int: factory} diff --git a/tests/util/test_resolve.py b/tests/util/test_resolve.py index 014e9fc..7b67700 100644 --- a/tests/util/test_resolve.py +++ b/tests/util/test_resolve.py @@ -1,3 +1,4 @@ +from fundi.scope import Scope from fundi import resolve, from_, scan, exceptions, FromType @@ -8,7 +9,7 @@ def dep(): def func(arg: int, arg1: str, arg2: None = from_(dep)): pass - for result in resolve({"arg": 1, "arg1": "value"}, scan(func), {}): + for result in resolve(Scope.from_legacy({"arg": 1, "arg1": "value"}), scan(func), {}): if not result.resolved: assert result.dependency is not None assert result.dependency.call is dep @@ -30,7 +31,7 @@ async def dep(): async def func(arg: int, arg1: str, arg2: None = from_(dep)): pass - for result in resolve({"arg": 1, "arg1": "value"}, scan(func), {}): + for result in resolve(Scope.from_legacy({"arg": 1, "arg1": "value"}), scan(func), {}): if not result.resolved: assert result.dependency is not None assert result.dependency.call is dep @@ -58,7 +59,9 @@ def func(arg: int, arg1: str, handler: from_(EventHandler)): event_handler = EventHandler() - for result in resolve({"arg": 1, "arg1": "value", "+1": event_handler}, scan(func), {}): + for result in resolve( + Scope.from_legacy({"arg": 1, "arg1": "value", "+1": event_handler}), scan(func), {} + ): assert result.parameter.name in ("arg", "arg1", "handler") if result.parameter.name == "arg1": @@ -85,7 +88,9 @@ def func(arg: int, arg1: str, handler: FromType[EventHandler]): event_handler = EventHandler() - for result in resolve({"arg": 1, "arg1": "value", "+1": event_handler}, scan(func), {}): + for result in resolve( + Scope.from_legacy({"arg": 1, "arg1": "value", "+1": event_handler}), scan(func), {} + ): assert result.parameter.name in ("arg", "arg1", "handler") if result.parameter.name == "arg1": @@ -103,7 +108,7 @@ def dep(): ... def func(arg: int = from_(dep)): ... - for result in resolve({}, scan(func), {}, override={dep: 2}): + for result in resolve(Scope(), scan(func), {}, override={dep: 2}): assert result.parameter.name == "arg" assert result.value == 2 @@ -116,7 +121,7 @@ def test_dep(): ... def func(arg: int = from_(dep)): ... - for result in resolve({}, scan(func), {}, override={dep: scan(test_dep)}): + for result in resolve(Scope(), scan(func), {}, override={dep: scan(test_dep)}): assert result.parameter.name == "arg" assert result.resolved is False @@ -129,7 +134,7 @@ def test_resolve_not_found(): def func(arg: int): ... try: - for result in resolve({}, scan(func), {}): + for result in resolve(Scope(), scan(func), {}): # This assertion would never evaluate under normal circumstances assert result is None except exceptions.ScopeValueNotFoundError as exc: