Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nested instantiation #2519

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
62 changes: 57 additions & 5 deletions tests/torchtune/config/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@
from torchtune.modules import RMSNorm


class Spice:
__slots__ = ["heat_level"]

def __init__(self, heat_level):
self.heat_level = heat_level


class Food:
__slots__ = ["seed", "ingredient"]

def __init__(self, seed, ingredient):
self.seed = seed
self.ingredient = ingredient


class TestInstantiate:
@pytest.fixture
def config(self):
Expand Down Expand Up @@ -58,11 +73,6 @@ def test_instantiate_node(self, config, module):
assert isinstance(actual, RMSNorm)
assert self.get_dim(actual) == self.get_dim(expected)

with pytest.raises(
InstantiationError, match="Cannot instantiate specified object"
):
_ = _instantiate_node(config.a)

def test_instantiate(self, config, module):
actual = instantiate(config.test)
expected = module
Expand All @@ -73,6 +83,18 @@ def test_instantiate(self, config, module):
actual = instantiate(config.test, eps=1e-4)
assert actual.eps == expected.eps

# should raise error if _component_ is not specified
with pytest.raises(
InstantiationError, match="Cannot instantiate specified object"
):
_ = instantiate(config)

with pytest.raises(
ValueError,
match="instantiate only supports DictConfigs or dicts, got <class 'str'>",
):
_ = instantiate(config.a)

# Test passing in positional args
del config.test.dim
actual = instantiate(config.test, 3)
Expand All @@ -92,3 +114,33 @@ def test_tokenizer_config_with_null(self):

tokenizer = instantiate(config.tokenizer)
assert tokenizer.max_seq_len is None

def test_nested_instantiation(self) -> None:
s = dedent(
"""\
food:
_component_: Food
seed: 0
ingredient:
_component_: Spice
heat_level: 5
"""
)
config = OmegaConf.create(s)

# Test successful nested instantiation
food = instantiate(config.food)
assert food.seed == 0
assert isinstance(food.ingredient, Spice)
assert food.ingredient.heat_level == 5

# Test overriding parameters
food = instantiate(config.food, seed=42)
assert food.seed == 42
assert food.ingredient.heat_level == 5

# Test overriding parameters of nested config
food = instantiate(
config.food, ingredient={"_component_": "Spice", "heat_level": 10}
)
assert food.ingredient.heat_level == 10
157 changes: 106 additions & 51 deletions torchtune/config/_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,136 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import inspect
import os
import sys
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, Union

from omegaconf import DictConfig, OmegaConf
from torchtune.config._errors import InstantiationError
from torchtune.config._utils import _get_component_from_path, _has_component
from torchtune.config._utils import _get_component_from_path


def _create_component(
_component_: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
"""Create an instance of a component with given arguments."""
return _component_(*args, **kwargs)


def _instantiate_node(node: Dict[str, Any], *args: Any) -> Any:
def _instantiate_node(
obj: Any,
*args: Any,
caller_globals: Optional[Dict[str, Any]] = None,
) -> Any:
"""
Creates the object specified in _component_ field with provided positional args
and kwargs already merged. Raises an InstantiationError if _component_ is not specified.
Instantiate a component from an object, recursively processing nested structures.

If the object is a dictionary with a '_component_' key, instantiate the component,
processing its arguments recursively. If it's a dictionary without '_component_' or a list,
process each item recursively. Otherwise, return the object unchanged.

Args:
obj (Any): Object to process (dict, list, or other).
*args (Any): Positional arguments for the component (used only at top level).
caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals.

Returns:
Any: Instantiated object or processed structure.
"""
if _has_component(node):
_component_ = _get_component_from_path(node.get("_component_"))
kwargs = {k: v for k, v in node.items() if k != "_component_"}
return _create_component(_component_, args, kwargs)
if isinstance(obj, dict) or isinstance(obj, DictConfig):
if "_component_" not in obj:
return {
k: _instantiate_node(v, caller_globals=caller_globals)
for k, v in obj.items()
}
else:
_component_ = _get_component_from_path(
obj["_component_"], caller_globals=caller_globals
)
kwargs = {
k: _instantiate_node(v, caller_globals=caller_globals)
for k, v in obj.items()
if k != "_component_"
}
return _create_component(_component_, args, kwargs)
elif isinstance(obj, list):
return [_instantiate_node(item, caller_globals=caller_globals) for item in obj]
else:
raise InstantiationError(
"Cannot instantiate specified object."
+ "\nMake sure you've specified a _component_ field with a valid dotpath."
)
return obj


def instantiate(
config: DictConfig,
config: Union[Dict[str, Any], DictConfig],
*args: Any,
caller_globals: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""
Given a DictConfig with a _component_ field specifying the object to instantiate and
additional fields for keyword arguments, create an instance of the specified object.
You can use this function to create the exact instance of a torchtune object you want
to use in your recipe using the specification from the config.
Instantiate a component from a configuration, recursively handling nested components.

This function also supports passing in positional args and keyword args within the
function call. These are automatically merged with the provided config, with keyword
args taking precedence.
Given a dict with a '_component_' field specifying the object to instantiate and
additional fields as keyword arguments, create an instance of the specified object.
Positional and keyword arguments passed in the call are merged with the config, with
keyword arguments taking precedence.

Based on Hydra's `instantiate` utility from Facebook Research:
https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/instantiate/_instantiate2.py#L148
Based on Hydra's `instantiate` utility.

Args:
config (DictConfig): a single field in the OmegaConf object parsed from the yaml file.
This is expected to have a _component_ field specifying the path of the object
to instantiate.
*args (Any): positional arguments to pass to the object to instantiate.
**kwargs (Any): keyword arguments to pass to the object to instantiate.

Examples:
>>> config.yaml:
>>> model:
>>> _component_: torchtune.models.llama2
>>> num_layers: 32
>>> num_heads: 32
>>> num_kv_heads: 32

>>> from torchtune import config
>>> vocab_size = 32000
>>> # Pass in vocab size as positional argument. Since it is positioned first
>>> # in llama2(), it must be specified first. Pass in other arguments as kwargs.
>>> # This will return an nn.Module directly for llama2 with specified args.
>>> model = config.instantiate(parsed_yaml.model, vocab_size, max_seq_len=4096, embed_dim=4096)
config (Union[Dict[str, Any], DictConfig]): Configuration with '_component_' and optional arguments.
*args (Any): Positional arguments for the component.
caller_globals (Optional[Dict[str, Any]]): Enable instantiating objects from caller's globals.
**kwargs (Any): Keyword arguments to override or add to the config.

Returns:
Any: the instantiated object.
Any: The instantiated object, or None if config is None.

Examples:
>>> class Spice:
>>> def __init__(self, heat_level):
>>> self.heat_level = heat_level
>>> class Food:
>>> def __init__(self, seed, ingredient):
>>> self.seed = seed
>>> self.ingredient = ingredient
>>> config = OmegaConf.create({
>>> '_component_': 'Food',
>>> 'seed': 0,
>>> 'ingredient': {'_component_': 'Spice', 'heat_level': 5}
>>> })
>>> food = instantiate(config, seed=42)
>>> print(food.seed) # 42
>>> print(food.ingredient.heat_level) # 5
>>> new_spice = {'_component_': 'Spice', 'heat_level': 10}
>>> food = instantiate(config, ingredient=new_spice)
>>> print(food.ingredient.heat_level) # 10

Raises:
ValueError: if config is not a DictConfig.
ValueError: If config is not a DictConfig.
InstantiationError: If the object to instantiate misses the '_component_' key.
"""

# Return None if config is None
if config is None:
return None
if not OmegaConf.is_dict(config):
raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}")

# Convert plain dict to DictConfig if necessary
if isinstance(config, dict):
config = OmegaConf.create(config)

elif not OmegaConf.is_dict(config):
raise ValueError(
f"instantiate only supports DictConfigs or dicts, got {type(config)}"
)

if "_component_" not in config:
raise InstantiationError(
"Cannot instantiate specified object."
+ "\nMake sure you've specified a _component_ field with a valid dotpath."
+ f"\nGot {config=}."
)

# Ensure local imports are able to be instantiated
if os.getcwd() not in sys.path:
Expand All @@ -109,4 +152,16 @@ def instantiate(
# Resolve all interpolations, or references to other fields within the same config
OmegaConf.resolve(config)

return _instantiate_node(OmegaConf.to_object(config), *args)
# caller → instantiate → _instantiate_node → _get_component_from_path
# To get the caller's globals, in case the the user is trying to instantiate some object from it,
# we step back (f_back) and get it, so `_get_component_from_path`` can use it.
if caller_globals is None:
current_frame = inspect.currentframe()
if current_frame and current_frame.f_back:
caller_globals = current_frame.f_back.f_globals

return _instantiate_node(
OmegaConf.to_container(config, resolve=True),
caller_globals=caller_globals,
*args,
)