diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index fe7da9f5c8..f8b8647fd7 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -137,6 +137,8 @@ def _resolve_target( if full_key: msg += f"\nfull_key: {full_key}" raise InstantiationException(msg) from e + elif _is_target(target): # Recursive target resolve + target = instantiate_node(target) if not callable(target): msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'" if full_key: diff --git a/tests/instantiate/__init__.py b/tests/instantiate/__init__.py index e4afaec733..d11a3f664f 100644 --- a/tests/instantiate/__init__.py +++ b/tests/instantiate/__init__.py @@ -85,6 +85,21 @@ def method() -> str: return "OuterClass.Nested.method return" +class ChainClass: + def __init__(self, a) -> None: + self.a = a + + def set_b(self, b) -> "ChainClass": + self.b = b + return self + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ChainClass): + return self.a == other.a and self.b == other.b + else: + return False + + def add_values(a: int, b: int) -> int: return a + b diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index f311271fba..924da7ac37 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -26,6 +26,7 @@ BClass, CenterCrop, CenterCropConf, + ChainClass, Compose, ComposeConf, IllegalType, @@ -254,6 +255,22 @@ def config(request: Any, src: Any) -> Any: 43, id="static_method", ), + # Check recursive + param( + { + "_target_":{ + "_target_": "builtins.getattr", + "_args_":[{ + "_target_": "tests.instantiate.ChainClass", + "a": 1 + }, "set_b"], + }, + 'b':2 + }, + {}, + ChainClass(1).set_b(2), + id="recursive_target", + ), # Check nested types and static methods param( {"_target_": "tests.instantiate.NestingClass"},