From 4da6e9906510e25228f5e54402becd59f28be14e Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 19 Mar 2025 11:57:11 -0700 Subject: [PATCH 01/17] nested instantiation --- torchtune/config/_instantiate.py | 144 ++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 51 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index 3a4d4f635d..c1959b0fac 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import copy import os import sys @@ -11,7 +10,7 @@ from omegaconf import DictConfig, OmegaConf from torchtune.config._errors import InstantiationError -from torchtune.config._utils import _get_component_from_path, _has_component +from torchtune.config._utils import _get_component_from_path def _create_component( @@ -19,23 +18,70 @@ def _create_component( args: Tuple[Any, ...], kwargs: Dict[str, Any], ) -> Any: + """Create an instance of a component with given arguments.""" return _component_(*args, **kwargs) -def _instantiate_node(node: Dict[str, Any], *args: Any) -> Any: +def _instantiate_node(config_dict: Dict[str, Any], *args: Any) -> Any: """ - Creates the object specified in _component_ field with provided positional args - and kwargs already merged. Raises an InstantiationError if _component_ is not specified. + Instantiate a component from a config dictionary. + + If the dictionary has a '_component_' field, retrieve the component, process + any nested arguments, and create the object with the given positional args. + + Args: + config_dict (Dict[str, Any]): Config dictionary with '_component_' and arguments. + *args (Any): Positional arguments for the component. + + Returns: + Any: The instantiated object. + + Examples: + >>> class Spice: + >>> def __init__(self, heat_level): + >>> self.heat_level = heat_level + >>> class Food: + >>> def __init__(self, seed, ingredient): + >>> self.seed = seed + >>> self.ingredient = ingredient + >>> config_dict = {'_component_': 'Food', 'seed': 42, + >>> 'ingredient': {'_component_': 'Spice', 'heat_level': 5}} + >>> food = _instantiate_node(config_dict) + >>> print(food.seed) # 42 + >>> print(food.ingredient.heat_level) # 5 + + Raises: + InstantiationError: If '_component_' is missing. """ - if _has_component(node): - _component_ = _get_component_from_path(node.get("_component_")) - kwargs = {k: v for k, v in node.items() if k != "_component_"} + if "_component_" in config_dict: + _component_ = _get_component_from_path(config_dict["_component_"]) + kwargs = { + k: _instantiate_nested(v) + for k, v in config_dict.items() + if k != "_component_" + } return _create_component(_component_, args, kwargs) - else: - raise InstantiationError( - "Cannot instantiate specified object." - + "\nMake sure you've specified a _component_ field with a valid dotpath." - ) + raise InstantiationError("Cannot instantiate: '_component_' field is missing.") + + +def _instantiate_nested(obj: Any) -> Any: + """ + Processes dictionaries and lists to recursively instantiate any nested '_component_' fields. + + Args: + obj (Any): Object to process (dict, list, or other). + + Returns: + Any: Object with nested components instantiated. + """ + if isinstance(obj, dict): + if "_component_" in obj: + config = OmegaConf.create(obj) + return instantiate(config) + return {k: _instantiate_nested(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_instantiate_nested(item) for item in obj] + return obj def instantiate( @@ -44,54 +90,53 @@ def instantiate( **kwargs: Any, ) -> Any: """ - Given a DictConfig with a _component_ field specifying the object to instantiate and - additional fields for keyword arguments, create an instance of the specified object. - You can use this function to create the exact instance of a torchtune object you want - to use in your recipe using the specification from the config. + Instantiate a component from a configuration, recursively handling nested components. - This function also supports passing in positional args and keyword args within the - function call. These are automatically merged with the provided config, with keyword - args taking precedence. + Given a DictConfig with a '_component_' field specifying the object to instantiate and + additional fields as keyword arguments, create an instance of the specified object. + Positional and keyword arguments passed in the call are merged with the config, with + keyword arguments taking precedence. - Based on Hydra's `instantiate` utility from Facebook Research: - https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/instantiate/_instantiate2.py#L148 + Based on Hydra's `instantiate` utility. Args: - config (DictConfig): a single field in the OmegaConf object parsed from the yaml file. - This is expected to have a _component_ field specifying the path of the object - to instantiate. - *args (Any): positional arguments to pass to the object to instantiate. - **kwargs (Any): keyword arguments to pass to the object to instantiate. - - Examples: - >>> config.yaml: - >>> model: - >>> _component_: torchtune.models.llama2 - >>> num_layers: 32 - >>> num_heads: 32 - >>> num_kv_heads: 32 - - >>> from torchtune import config - >>> vocab_size = 32000 - >>> # Pass in vocab size as positional argument. Since it is positioned first - >>> # in llama2(), it must be specified first. Pass in other arguments as kwargs. - >>> # This will return an nn.Module directly for llama2 with specified args. - >>> model = config.instantiate(parsed_yaml.model, vocab_size, max_seq_len=4096, embed_dim=4096) + config (DictConfig): Configuration with '_component_' and optional arguments. + *args (Any): Positional arguments for the component. + **kwargs (Any): Keyword arguments to override or add to the config. Returns: - Any: the instantiated object. + Any: The instantiated object, or None if config is None. + + Examples: + >>> class Spice: + >>> def __init__(self, heat_level): + >>> self.heat_level = heat_level + >>> class Food: + >>> def __init__(self, seed, ingredient): + >>> self.seed = seed + >>> self.ingredient = ingredient + >>> config = OmegaConf.create({ + >>> '_component_': 'Food', + >>> 'seed': 0, + >>> 'ingredient': {'_component_': 'Spice', 'heat_level': 5} + >>> }) + >>> food = instantiate(config, seed=42) + >>> print(food.seed) # 42 + >>> print(food.ingredient.heat_level) # 5 + >>> new_spice = {'_component_': 'Spice', 'heat_level': 10} + >>> food = instantiate(config, ingredient=new_spice) + >>> print(food.ingredient.heat_level) # 10 Raises: - ValueError: if config is not a DictConfig. - """ + ValueError: If config is not a DictConfig. - # Return None if config is None + Note: Modifies sys.path to include the current working directory for local imports. + """ if config is None: return None if not OmegaConf.is_dict(config): raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}") - # Ensure local imports are able to be instantiated if os.getcwd() not in sys.path: sys.path.append(os.getcwd()) @@ -103,10 +148,7 @@ def instantiate( config = config_copy if kwargs: - # This overwrites any repeated fields in the config with kwargs config = OmegaConf.merge(config, kwargs) - # Resolve all interpolations, or references to other fields within the same config OmegaConf.resolve(config) - - return _instantiate_node(OmegaConf.to_object(config), *args) + return _instantiate_node(OmegaConf.to_container(config, resolve=True), *args) From ddad6f30ae19f2f32d917c53138ffc17cb4c7ddf Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 19 Mar 2025 12:19:57 -0700 Subject: [PATCH 02/17] add unit test --- tests/torchtune/config/test_instantiate.py | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py index 7bb1fe147d..d1dfe12c2b 100644 --- a/tests/torchtune/config/test_instantiate.py +++ b/tests/torchtune/config/test_instantiate.py @@ -92,3 +92,29 @@ def test_tokenizer_config_with_null(self): tokenizer = instantiate(config.tokenizer) assert tokenizer.max_seq_len is None + + def test_nested_instantiation(self) -> None: + class Foo: + def __init__(self, bar): + self.bar = bar + + def __call__(self, x): + return self.bar(x) + + class Bar: + def __call__(self, x): + return x + 1 + + s = dedent( + """\ + foo: + _component_: foo + bar: + _component_: bar + """ + ) + config = OmegaConf.create(s) + + foo = instantiate(config.foo) + output = foo(1) + assert output == 2, f"Foo should call bar and return 1+1. Got {output} instead." From 2b651dba3626353d6129410630878f6c0a7d7488 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 08:05:01 -0700 Subject: [PATCH 03/17] add comments --- torchtune/config/_instantiate.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index c1959b0fac..f67292a7b7 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -61,7 +61,11 @@ def _instantiate_node(config_dict: Dict[str, Any], *args: Any) -> Any: if k != "_component_" } return _create_component(_component_, args, kwargs) - raise InstantiationError("Cannot instantiate: '_component_' field is missing.") + raise InstantiationError( + "Cannot instantiate specified object." + + "\nMake sure you've specified a _component_ field with a valid dotpath." + + f"\nGot {config_dict=}." + ) def _instantiate_nested(obj: Any) -> Any: @@ -134,9 +138,11 @@ def instantiate( """ if config is None: return None + if not OmegaConf.is_dict(config): raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}") + # Ensure local imports are able to be instantiated if os.getcwd() not in sys.path: sys.path.append(os.getcwd()) @@ -148,7 +154,10 @@ def instantiate( config = config_copy if kwargs: + # This overwrites any repeated fields in the config with kwargs config = OmegaConf.merge(config, kwargs) + # Resolve all interpolations, or references to other fields within the same config OmegaConf.resolve(config) + return _instantiate_node(OmegaConf.to_container(config, resolve=True), *args) From 1f070a51ebe04cc3a424c2c232fae767a46e255a Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 08:27:28 -0700 Subject: [PATCH 04/17] update test --- tests/torchtune/config/test_instantiate.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py index d1dfe12c2b..5568f8ca12 100644 --- a/tests/torchtune/config/test_instantiate.py +++ b/tests/torchtune/config/test_instantiate.py @@ -108,13 +108,15 @@ def __call__(self, x): s = dedent( """\ foo: - _component_: foo + _component_: Foo bar: - _component_: bar + _component_: Bar """ ) config = OmegaConf.create(s) foo = instantiate(config.foo) output = foo(1) - assert output == 2, f"Foo should call bar and return 1+1. Got {output} instead." + assert ( + output == 2 + ), f"Foo should call bar and return 1+1. Got {output} instead for config {s}." From 138dc246cabe32d06770c2dc54cc241b1f8bb04c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 08:42:06 -0700 Subject: [PATCH 05/17] move to globals --- tests/torchtune/config/test_instantiate.py | 24 ++++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py index 5568f8ca12..c9b6421631 100644 --- a/tests/torchtune/config/test_instantiate.py +++ b/tests/torchtune/config/test_instantiate.py @@ -19,6 +19,19 @@ from torchtune.modules import RMSNorm +class Foo: + def __init__(self, bar): + self.bar = bar + + def __call__(self, x): + return self.bar(x) + + +class Bar: + def __call__(self, x): + return x + 1 + + class TestInstantiate: @pytest.fixture def config(self): @@ -94,17 +107,6 @@ def test_tokenizer_config_with_null(self): assert tokenizer.max_seq_len is None def test_nested_instantiation(self) -> None: - class Foo: - def __init__(self, bar): - self.bar = bar - - def __call__(self, x): - return self.bar(x) - - class Bar: - def __call__(self, x): - return x + 1 - s = dedent( """\ foo: From bdd9aa8038780c5f72b833cf9511bb7389287a6e Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 09:20:46 -0700 Subject: [PATCH 06/17] enable instantiation from globals --- torchtune/config/_instantiate.py | 42 +++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index f67292a7b7..1e0a5d0db9 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -4,9 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy + +import inspect import os import sys -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, Optional, Tuple from omegaconf import DictConfig, OmegaConf from torchtune.config._errors import InstantiationError @@ -22,7 +24,11 @@ def _create_component( return _component_(*args, **kwargs) -def _instantiate_node(config_dict: Dict[str, Any], *args: Any) -> Any: +def _instantiate_node( + config_dict: Dict[str, Any], + caller_globals: Optional[Dict[str, Any]] = None, + *args: Any, +) -> Any: """ Instantiate a component from a config dictionary. @@ -31,6 +37,7 @@ def _instantiate_node(config_dict: Dict[str, Any], *args: Any) -> Any: Args: config_dict (Dict[str, Any]): Config dictionary with '_component_' and arguments. + caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals. *args (Any): Positional arguments for the component. Returns: @@ -54,9 +61,11 @@ def _instantiate_node(config_dict: Dict[str, Any], *args: Any) -> Any: InstantiationError: If '_component_' is missing. """ if "_component_" in config_dict: - _component_ = _get_component_from_path(config_dict["_component_"]) + _component_ = _get_component_from_path( + config_dict["_component_"], caller_globals=caller_globals + ) kwargs = { - k: _instantiate_nested(v) + k: _instantiate_nested(v, caller_globals) for k, v in config_dict.items() if k != "_component_" } @@ -68,12 +77,15 @@ def _instantiate_node(config_dict: Dict[str, Any], *args: Any) -> Any: ) -def _instantiate_nested(obj: Any) -> Any: +def _instantiate_nested( + obj: Any, caller_globals: Optional[Dict[str, Any]] = None +) -> Any: """ Processes dictionaries and lists to recursively instantiate any nested '_component_' fields. Args: obj (Any): Object to process (dict, list, or other). + caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals. Returns: Any: Object with nested components instantiated. @@ -81,10 +93,10 @@ def _instantiate_nested(obj: Any) -> Any: if isinstance(obj, dict): if "_component_" in obj: config = OmegaConf.create(obj) - return instantiate(config) - return {k: _instantiate_nested(v) for k, v in obj.items()} + return instantiate(config, caller_globals=caller_globals) + return {k: _instantiate_nested(v, caller_globals) for k, v in obj.items()} elif isinstance(obj, list): - return [_instantiate_nested(item) for item in obj] + return [_instantiate_nested(item, caller_globals) for item in obj] return obj @@ -160,4 +172,16 @@ def instantiate( # Resolve all interpolations, or references to other fields within the same config OmegaConf.resolve(config) - return _instantiate_node(OmegaConf.to_container(config, resolve=True), *args) + # Where this is called → instantiate → _instantiate_node → _get_component_from_path + # if the user is instantiating a local object, we have to step back (f_back) and get these globals + # so that `_get_component_from_path`` can use it. + caller_globals = None + current_frame = inspect.currentframe() + if current_frame and current_frame.f_back: + caller_globals = current_frame.f_back.f_globals + + return _instantiate_node( + OmegaConf.to_container(config, resolve=True), + caller_globals=caller_globals, + *args, + ) From 95285a42347027a1147efa2180bdfefb744e1060 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 09:32:22 -0700 Subject: [PATCH 07/17] change args order --- torchtune/config/_instantiate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index 1e0a5d0db9..e95a105a5f 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -26,8 +26,8 @@ def _create_component( def _instantiate_node( config_dict: Dict[str, Any], - caller_globals: Optional[Dict[str, Any]] = None, *args: Any, + caller_globals: Optional[Dict[str, Any]] = None, ) -> Any: """ Instantiate a component from a config dictionary. @@ -37,8 +37,8 @@ def _instantiate_node( Args: config_dict (Dict[str, Any]): Config dictionary with '_component_' and arguments. - caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals. *args (Any): Positional arguments for the component. + caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals. Returns: Any: The instantiated object. From 7bcc885ca58d394a65832a3901417ddf76baf1c2 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 10:33:17 -0700 Subject: [PATCH 08/17] and another one --- torchtune/config/_instantiate.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index e95a105a5f..4f42686189 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -8,7 +8,7 @@ import inspect import os import sys -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, Union from omegaconf import DictConfig, OmegaConf from torchtune.config._errors import InstantiationError @@ -92,8 +92,7 @@ def _instantiate_nested( """ if isinstance(obj, dict): if "_component_" in obj: - config = OmegaConf.create(obj) - return instantiate(config, caller_globals=caller_globals) + return instantiate(obj, caller_globals=caller_globals) return {k: _instantiate_nested(v, caller_globals) for k, v in obj.items()} elif isinstance(obj, list): return [_instantiate_nested(item, caller_globals) for item in obj] @@ -101,14 +100,14 @@ def _instantiate_nested( def instantiate( - config: DictConfig, + config: Union[Dict[str, Any], DictConfig], *args: Any, **kwargs: Any, ) -> Any: """ Instantiate a component from a configuration, recursively handling nested components. - Given a DictConfig with a '_component_' field specifying the object to instantiate and + Given a dict with a '_component_' field specifying the object to instantiate and additional fields as keyword arguments, create an instance of the specified object. Positional and keyword arguments passed in the call are merged with the config, with keyword arguments taking precedence. @@ -116,7 +115,7 @@ def instantiate( Based on Hydra's `instantiate` utility. Args: - config (DictConfig): Configuration with '_component_' and optional arguments. + config (Union[Dict[str, Any], DictConfig]): Configuration with '_component_' and optional arguments. *args (Any): Positional arguments for the component. **kwargs (Any): Keyword arguments to override or add to the config. @@ -151,8 +150,13 @@ def instantiate( if config is None: return None - if not OmegaConf.is_dict(config): - raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}") + # Convert plain dict to DictConfig if necessary + if isinstance(config, dict): + config = OmegaConf.create(config) + elif not OmegaConf.is_dict(config): + raise ValueError( + f"instantiate only supports DictConfigs or dicts, got {type(config)}" + ) # Ensure local imports are able to be instantiated if os.getcwd() not in sys.path: From a00c8533c9b884d5533d2171e5db2b577cce1e03 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 11:16:39 -0700 Subject: [PATCH 09/17] add caller_globals to instantiate --- torchtune/config/_instantiate.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index 4f42686189..eb74348d96 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -102,6 +102,7 @@ def _instantiate_nested( def instantiate( config: Union[Dict[str, Any], DictConfig], *args: Any, + caller_globals: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """ @@ -117,6 +118,7 @@ def instantiate( Args: config (Union[Dict[str, Any], DictConfig]): Configuration with '_component_' and optional arguments. *args (Any): Positional arguments for the component. + caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals. **kwargs (Any): Keyword arguments to override or add to the config. Returns: @@ -179,10 +181,10 @@ def instantiate( # Where this is called → instantiate → _instantiate_node → _get_component_from_path # if the user is instantiating a local object, we have to step back (f_back) and get these globals # so that `_get_component_from_path`` can use it. - caller_globals = None - current_frame = inspect.currentframe() - if current_frame and current_frame.f_back: - caller_globals = current_frame.f_back.f_globals + if caller_globals is None: + current_frame = inspect.currentframe() + if current_frame and current_frame.f_back: + caller_globals = current_frame.f_back.f_globals return _instantiate_node( OmegaConf.to_container(config, resolve=True), From 6977fc79ab7a24e8694273740ea7176652c286e0 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 11:23:55 -0700 Subject: [PATCH 10/17] update comment --- torchtune/config/_instantiate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index eb74348d96..0b98955c84 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -178,9 +178,10 @@ def instantiate( # Resolve all interpolations, or references to other fields within the same config OmegaConf.resolve(config) - # Where this is called → instantiate → _instantiate_node → _get_component_from_path - # if the user is instantiating a local object, we have to step back (f_back) and get these globals - # so that `_get_component_from_path`` can use it. + # caller → instantiate → _instantiate_node → _get_component_from_path + # To get the caller's globals, in case the the user is trying to instantiate some object from it, + # we step back (f_back) and get it, so `_get_component_from_path`` can use it. + # For nested instantiation, this will NOT be None anymore, meaning that we preserve the caller's globals if caller_globals is None: current_frame = inspect.currentframe() if current_frame and current_frame.f_back: From b3f1de421190e3facd3164ce560b4c9b30206953 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 13:20:03 -0700 Subject: [PATCH 11/17] merge fn --- torchtune/config/_instantiate.py | 85 +++++++++++--------------------- 1 file changed, 28 insertions(+), 57 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index 0b98955c84..41641c94b8 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -25,78 +25,44 @@ def _create_component( def _instantiate_node( - config_dict: Dict[str, Any], + obj: Any, *args: Any, caller_globals: Optional[Dict[str, Any]] = None, ) -> Any: """ - Instantiate a component from a config dictionary. + Instantiate a component from an object, recursively processing nested structures. - If the dictionary has a '_component_' field, retrieve the component, process - any nested arguments, and create the object with the given positional args. + If the object is a dictionary with a '_component_' key, instantiate the component, + processing its arguments recursively. If it's a dictionary without '_component_' or a list, + process each item recursively. Otherwise, return the object unchanged. Args: - config_dict (Dict[str, Any]): Config dictionary with '_component_' and arguments. - *args (Any): Positional arguments for the component. + obj (Any): Object to process (dict, list, or other). + *args (Any): Positional arguments for the component (used only at top level). caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals. Returns: - Any: The instantiated object. - - Examples: - >>> class Spice: - >>> def __init__(self, heat_level): - >>> self.heat_level = heat_level - >>> class Food: - >>> def __init__(self, seed, ingredient): - >>> self.seed = seed - >>> self.ingredient = ingredient - >>> config_dict = {'_component_': 'Food', 'seed': 42, - >>> 'ingredient': {'_component_': 'Spice', 'heat_level': 5}} - >>> food = _instantiate_node(config_dict) - >>> print(food.seed) # 42 - >>> print(food.ingredient.heat_level) # 5 - - Raises: - InstantiationError: If '_component_' is missing. + Any: Instantiated object or processed structure. """ - if "_component_" in config_dict: + if isinstance(obj, dict) and "_component_" in obj: _component_ = _get_component_from_path( - config_dict["_component_"], caller_globals=caller_globals + obj["_component_"], caller_globals=caller_globals ) kwargs = { - k: _instantiate_nested(v, caller_globals) - for k, v in config_dict.items() + k: _instantiate_node(v, caller_globals=caller_globals) + for k, v in obj.items() if k != "_component_" } return _create_component(_component_, args, kwargs) - raise InstantiationError( - "Cannot instantiate specified object." - + "\nMake sure you've specified a _component_ field with a valid dotpath." - + f"\nGot {config_dict=}." - ) - - -def _instantiate_nested( - obj: Any, caller_globals: Optional[Dict[str, Any]] = None -) -> Any: - """ - Processes dictionaries and lists to recursively instantiate any nested '_component_' fields. - - Args: - obj (Any): Object to process (dict, list, or other). - caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals. - - Returns: - Any: Object with nested components instantiated. - """ - if isinstance(obj, dict): - if "_component_" in obj: - return instantiate(obj, caller_globals=caller_globals) - return {k: _instantiate_nested(v, caller_globals) for k, v in obj.items()} + elif isinstance(obj, dict): + return { + k: _instantiate_node(v, caller_globals=caller_globals) + for k, v in obj.items() + } elif isinstance(obj, list): - return [_instantiate_nested(item, caller_globals) for item in obj] - return obj + return [_instantiate_node(item, caller_globals=caller_globals) for item in obj] + else: + return obj def instantiate( @@ -146,15 +112,21 @@ def instantiate( Raises: ValueError: If config is not a DictConfig. - - Note: Modifies sys.path to include the current working directory for local imports. + InstantiationError: If the object to instantiate misses the '_component_' key. """ if config is None: return None # Convert plain dict to DictConfig if necessary if isinstance(config, dict): + if "_component_" not in config: + raise InstantiationError( + "Cannot instantiate specified object." + + "\nMake sure you've specified a _component_ field with a valid dotpath." + + f"\nGot {config=}." + ) config = OmegaConf.create(config) + elif not OmegaConf.is_dict(config): raise ValueError( f"instantiate only supports DictConfigs or dicts, got {type(config)}" @@ -181,7 +153,6 @@ def instantiate( # caller → instantiate → _instantiate_node → _get_component_from_path # To get the caller's globals, in case the the user is trying to instantiate some object from it, # we step back (f_back) and get it, so `_get_component_from_path`` can use it. - # For nested instantiation, this will NOT be None anymore, meaning that we preserve the caller's globals if caller_globals is None: current_frame = inspect.currentframe() if current_frame and current_frame.f_back: From d276bad8d24cd40592017b5477615e7ecf9d03b5 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 13:58:38 -0700 Subject: [PATCH 12/17] DictConfig is not a dict --- torchtune/config/_instantiate.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index 41641c94b8..186093b001 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -44,21 +44,22 @@ def _instantiate_node( Returns: Any: Instantiated object or processed structure. """ - if isinstance(obj, dict) and "_component_" in obj: - _component_ = _get_component_from_path( - obj["_component_"], caller_globals=caller_globals - ) - kwargs = { - k: _instantiate_node(v, caller_globals=caller_globals) - for k, v in obj.items() - if k != "_component_" - } - return _create_component(_component_, args, kwargs) - elif isinstance(obj, dict): - return { - k: _instantiate_node(v, caller_globals=caller_globals) - for k, v in obj.items() - } + if isinstance(obj, dict) or isinstance(obj, DictConfig): + if "_component_" not in obj: + return { + k: _instantiate_node(v, caller_globals=caller_globals) + for k, v in obj.items() + } + else: + _component_ = _get_component_from_path( + obj["_component_"], caller_globals=caller_globals + ) + kwargs = { + k: _instantiate_node(v, caller_globals=caller_globals) + for k, v in obj.items() + if k != "_component_" + } + return _create_component(_component_, args, kwargs) elif isinstance(obj, list): return [_instantiate_node(item, caller_globals=caller_globals) for item in obj] else: From d59c77dcae63fd76d06509bd4801b2916ea7d130 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 14:18:42 -0700 Subject: [PATCH 13/17] improve testing --- tests/torchtune/config/test_instantiate.py | 65 ++++++++++++++-------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py index c9b6421631..f3445f2447 100644 --- a/tests/torchtune/config/test_instantiate.py +++ b/tests/torchtune/config/test_instantiate.py @@ -19,17 +19,19 @@ from torchtune.modules import RMSNorm -class Foo: - def __init__(self, bar): - self.bar = bar +class Spice: + __slots__ = ["heat_level"] - def __call__(self, x): - return self.bar(x) + def __init__(self, heat_level): + self.heat_level = heat_level -class Bar: - def __call__(self, x): - return x + 1 +class Food: + __slots__ = ["seed", "ingredient"] + + def __init__(self, seed, ingredient): + self.seed = seed + self.ingredient = ingredient class TestInstantiate: @@ -71,11 +73,6 @@ def test_instantiate_node(self, config, module): assert isinstance(actual, RMSNorm) assert self.get_dim(actual) == self.get_dim(expected) - with pytest.raises( - InstantiationError, match="Cannot instantiate specified object" - ): - _ = _instantiate_node(config.a) - def test_instantiate(self, config, module): actual = instantiate(config.test) expected = module @@ -91,6 +88,17 @@ def test_instantiate(self, config, module): actual = instantiate(config.test, 3) assert self.get_dim(actual) == 3 + # should raise error if _component_ is not specified + with pytest.raises( + InstantiationError, match="Cannot instantiate specified object" + ): + _ = instantiate(config) + + with pytest.raises( + InstantiationError, match="Cannot instantiate specified object" + ): + _ = instantiate(config.a) + def test_tokenizer_config_with_null(self): assets = Path(__file__).parent.parent.parent / "assets" s = dedent( @@ -109,16 +117,29 @@ def test_tokenizer_config_with_null(self): def test_nested_instantiation(self) -> None: s = dedent( """\ - foo: - _component_: Foo - bar: - _component_: Bar + food: + _component_: Food + seed: 0 + ingredient: + _component_: Spice + heat_level: 5 """ ) config = OmegaConf.create(s) - foo = instantiate(config.foo) - output = foo(1) - assert ( - output == 2 - ), f"Foo should call bar and return 1+1. Got {output} instead for config {s}." + # Test successful nested instantiation + food = instantiate(config.food) + assert food.seed == 0 + assert isinstance(food.ingredient, Spice) + assert food.ingredient.heat_level == 5 + + # Test overriding parameters + food = instantiate(config.food, seed=42) + assert food.seed == 42 + assert food.ingredient.heat_level == 5 + + # Test overriding parameters of nested config + food = instantiate( + config.food, ingredient={"_component_": "Spice", "heat_level": 10} + ) + assert food.ingredient.heat_level == 10 From 43b2a239c5b196f3389d098bd94421710120e4ac Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 14:38:32 -0700 Subject: [PATCH 14/17] and again --- tests/torchtune/config/test_instantiate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py index f3445f2447..47bacc2b86 100644 --- a/tests/torchtune/config/test_instantiate.py +++ b/tests/torchtune/config/test_instantiate.py @@ -83,11 +83,6 @@ def test_instantiate(self, config, module): actual = instantiate(config.test, eps=1e-4) assert actual.eps == expected.eps - # Test passing in positional args - del config.test.dim - actual = instantiate(config.test, 3) - assert self.get_dim(actual) == 3 - # should raise error if _component_ is not specified with pytest.raises( InstantiationError, match="Cannot instantiate specified object" @@ -99,6 +94,11 @@ def test_instantiate(self, config, module): ): _ = instantiate(config.a) + # Test passing in positional args + del config.test.dim + actual = instantiate(config.test, 3) + assert self.get_dim(actual) == 3 + def test_tokenizer_config_with_null(self): assets = Path(__file__).parent.parent.parent / "assets" s = dedent( From 47ad1cfbcf4fe4d33127e037a295ffbd7a6759fd Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 14:51:23 -0700 Subject: [PATCH 15/17] and again --- torchtune/config/_instantiate.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py index 186093b001..c4d4e1b211 100644 --- a/torchtune/config/_instantiate.py +++ b/torchtune/config/_instantiate.py @@ -120,12 +120,6 @@ def instantiate( # Convert plain dict to DictConfig if necessary if isinstance(config, dict): - if "_component_" not in config: - raise InstantiationError( - "Cannot instantiate specified object." - + "\nMake sure you've specified a _component_ field with a valid dotpath." - + f"\nGot {config=}." - ) config = OmegaConf.create(config) elif not OmegaConf.is_dict(config): @@ -133,6 +127,13 @@ def instantiate( f"instantiate only supports DictConfigs or dicts, got {type(config)}" ) + if "_component_" not in config: + raise InstantiationError( + "Cannot instantiate specified object." + + "\nMake sure you've specified a _component_ field with a valid dotpath." + + f"\nGot {config=}." + ) + # Ensure local imports are able to be instantiated if os.getcwd() not in sys.path: sys.path.append(os.getcwd()) From 9ab086e0caacc2c621cd79611caea52867b4577f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 14:59:06 -0700 Subject: [PATCH 16/17] and another one --- tests/torchtune/config/test_instantiate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py index 47bacc2b86..27ea0d611b 100644 --- a/tests/torchtune/config/test_instantiate.py +++ b/tests/torchtune/config/test_instantiate.py @@ -90,7 +90,8 @@ def test_instantiate(self, config, module): _ = instantiate(config) with pytest.raises( - InstantiationError, match="Cannot instantiate specified object" + InstantiationError, + match="instantiate only supports DictConfigs or dicts, got ", ): _ = instantiate(config.a) From 8cdf14d9af31a487feff494e4aff69201e4e26d5 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 27 Mar 2025 15:06:57 -0700 Subject: [PATCH 17/17] and again --- tests/torchtune/config/test_instantiate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py index 27ea0d611b..15cd1112c6 100644 --- a/tests/torchtune/config/test_instantiate.py +++ b/tests/torchtune/config/test_instantiate.py @@ -90,7 +90,7 @@ def test_instantiate(self, config, module): _ = instantiate(config) with pytest.raises( - InstantiationError, + ValueError, match="instantiate only supports DictConfigs or dicts, got ", ): _ = instantiate(config.a)