-
Notifications
You must be signed in to change notification settings - Fork 565
/
Copy pathtest_instantiate.py
146 lines (120 loc) · 4.09 KB
/
test_instantiate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from textwrap import dedent
import pytest
from omegaconf import OmegaConf
from torchtune.config._errors import InstantiationError
from torchtune.config._instantiate import (
_create_component,
_instantiate_node,
instantiate,
)
from torchtune.config._utils import _has_component
from torchtune.modules import RMSNorm
class Spice:
    """Leaf test component; used as the nested ``ingredient`` of ``Food``."""

    __slots__ = ("heat_level",)

    def __init__(self, heat_level):
        # The single constructor argument is kept verbatim so tests can read it back.
        self.heat_level = heat_level
class Food:
    """Outer test component for the nested-instantiation test.

    Stores ``seed`` and ``ingredient`` unchanged; in the tests ``ingredient``
    is itself built from a nested ``_component_`` node.
    """

    __slots__ = ("seed", "ingredient")

    def __init__(self, seed, ingredient):
        self.seed, self.ingredient = seed, ingredient
class TestInstantiate:
    """Tests for the config helpers ``_has_component``, ``_create_component``,
    ``_instantiate_node`` and ``instantiate``."""

    @pytest.fixture
    def config(self):
        """Config with two plain string entries (``a``, ``b``) and one
        instantiable ``test`` node describing an ``RMSNorm`` of dim 5."""
        s = dedent(
            """\
            a: b
            b: c
            test:
              _component_: torchtune.modules.RMSNorm
              dim: 5
            """
        )
        return OmegaConf.create(s)

    @pytest.fixture
    def module(self):
        """Reference ``RMSNorm`` built directly (not from config) for comparisons."""
        return RMSNorm(dim=5, eps=1e-4)

    def get_dim(self, rms_norm: RMSNorm):
        """Return the size of the first dimension of ``rms_norm.scale``."""
        return rms_norm.scale.shape[0]

    def test_has_path(self, config):
        """``_has_component`` is truthy for a node with ``_component_`` and
        falsy for a plain string value."""
        assert _has_component(config.test)
        assert not _has_component(config.a)

    def test_call_object(self, module):
        """``_create_component`` builds the object from positional args and kwargs."""
        obj = RMSNorm
        args = (5,)
        kwargs = {"eps": 1e-4}
        actual = _create_component(obj, args, kwargs)
        expected = module
        assert isinstance(actual, RMSNorm)
        assert self.get_dim(actual) == self.get_dim(expected)
        assert actual.eps == expected.eps

    def test_instantiate_node(self, config, module):
        """``_instantiate_node`` builds the object described by a config node."""
        actual = _instantiate_node(config.test)
        expected = module
        assert isinstance(actual, RMSNorm)
        assert self.get_dim(actual) == self.get_dim(expected)

    def test_instantiate(self, config, module):
        """``instantiate`` handles plain configs, kwargs, positional args and
        rejects invalid inputs."""
        actual = instantiate(config.test)
        expected = module
        assert isinstance(actual, RMSNorm)
        assert self.get_dim(actual) == self.get_dim(expected)

        # Test passing in kwargs: eps is not in the config, so it must come
        # from the keyword argument; dim must still come from the config.
        actual = instantiate(config.test, eps=1e-4)
        assert isinstance(actual, RMSNorm)
        assert self.get_dim(actual) == self.get_dim(expected)
        assert actual.eps == expected.eps

        # should raise error if _component_ is not specified
        with pytest.raises(
            InstantiationError, match="Cannot instantiate specified object"
        ):
            _ = instantiate(config)

        # Non-dict nodes (here the string value of ``a``) are rejected.
        with pytest.raises(
            ValueError,
            match="instantiate only supports DictConfigs or dicts, got <class 'str'>",
        ):
            _ = instantiate(config.a)

        # Test passing in positional args. Deleting ``dim`` mutates the
        # fixture, which is function-scoped, so other tests are unaffected.
        del config.test.dim
        actual = instantiate(config.test, 3)
        assert self.get_dim(actual) == 3

    def test_tokenizer_config_with_null(self):
        """A YAML ``null`` for ``max_seq_len`` reaches the tokenizer as ``None``."""
        assets = Path(__file__).parent.parent.parent / "assets"
        config = OmegaConf.create(
            dedent(
                """\
                tokenizer:
                  _component_: torchtune.models.llama2.llama2_tokenizer
                  max_seq_len: null
                """
            )
        )
        # Set the path on the parsed config rather than splicing it into the
        # YAML text: an unquoted filesystem path containing YAML syntax
        # (e.g. " #" or ": ") would otherwise be mis-parsed.
        config.tokenizer.path = str(assets / "m.model")
        tokenizer = instantiate(config.tokenizer)
        assert tokenizer.max_seq_len is None

    def test_nested_instantiation(self) -> None:
        """Nested ``_component_`` nodes are instantiated recursively, and
        kwargs can override both top-level and nested values."""
        # NOTE(review): ``Food``/``Spice`` are bare names, so this relies on
        # ``instantiate`` resolving them against this test module's globals
        # (presumably via the caller's frame) — confirm if that lookup changes.
        s = dedent(
            """\
            food:
              _component_: Food
              seed: 0
              ingredient:
                _component_: Spice
                heat_level: 5
            """
        )
        config = OmegaConf.create(s)

        # Test successful nested instantiation
        food = instantiate(config.food)
        assert food.seed == 0
        assert isinstance(food.ingredient, Spice)
        assert food.ingredient.heat_level == 5

        # Test overriding parameters
        food = instantiate(config.food, seed=42)
        assert food.seed == 42
        assert isinstance(food.ingredient, Spice)
        assert food.ingredient.heat_level == 5

        # Test overriding parameters of nested config; the untouched
        # top-level ``seed`` must keep its config value.
        food = instantiate(
            config.food, ingredient={"_component_": "Spice", "heat_level": 10}
        )
        assert food.seed == 0
        assert isinstance(food.ingredient, Spice)
        assert food.ingredient.heat_level == 10