Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit a27e7a6

Browse files
mannatsinghfacebook-github-bot
authored andcommitted
Implement config validation to find unused keys (#665)
Summary: Pull Request resolved: #665 Implement a `ClassyConfigDict` type which supports tracking reads and freezing the map (the latter is unused currently). Added it to `build_task` to catch cases where we don't use any keys passed by users. This will not catch all instances, like when some components do a deepcopy - we assume all the keys and sub-keys are read in such a situation Differential Revision: D25321360 fbshipit-source-id: ff71e61298baa6c30d0e4719ec5512a20fda955c
1 parent 8592b83 commit a27e7a6

File tree

8 files changed

+356
-7
lines changed

8 files changed

+356
-7
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from .classy_config_dict import ClassyConfigDict
7+
from .config_error import ConfigError, ConfigUnusedKeysError
8+
9+
__all__ = ["ClassyConfigDict", "ConfigError", "ConfigUnusedKeysError"]
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
import json
8+
from collections.abc import MutableMapping, Mapping
9+
10+
from .config_error import ConfigUnusedKeysError
11+
12+
13+
class ClassyConfigDict(MutableMapping):
14+
"""Mapping which can be made immutable. Also supports tracking unused keys."""
15+
16+
def __init__(self, *args, **kwargs):
17+
"""Create a ClassyConfigDict.
18+
19+
Supports the same API as a dict and recursively converts all dicts to
20+
ClassyConfigDicts.
21+
"""
22+
23+
# NOTE: Another way to implement this would be to subclass dict, but since dict
24+
# is a built-in, it isn't treated like a regular MutableMapping, and calls like
25+
# func(**map) are handled mysteriously, probably interpreter dependent.
26+
# The downside with this implementation is that this isn't a full dict and is
27+
# just a mapping, which means some features like JSON serialization don't work
28+
29+
self._dict = dict(*args, **kwargs)
30+
self._frozen = False
31+
self._keys_read = set()
32+
for k, v in self._dict.items():
33+
self._dict[k] = self._from_dict(v)
34+
35+
@classmethod
36+
def _from_dict(cls, obj):
37+
"""Recursively convert all dicts inside obj to ClassyConfigDicts"""
38+
39+
if isinstance(obj, Mapping):
40+
obj = ClassyConfigDict({k: cls._from_dict(v) for k, v in obj.items()})
41+
elif isinstance(obj, (list, tuple)):
42+
# tuples are also converted to lists
43+
obj = [cls._from_dict(v) for v in obj]
44+
return obj
45+
46+
def to_dict(self):
47+
"""Return a vanilla Python dict, converting dicts recursively"""
48+
return self._to_dict(self)
49+
50+
@classmethod
51+
def _to_dict(cls, obj):
52+
"""Recursively convert obj to vanilla Python dicts"""
53+
if isinstance(obj, ClassyConfigDict):
54+
obj = {k: cls._to_dict(v) for k, v in obj.items()}
55+
elif isinstance(obj, (list, tuple)):
56+
# tuples are also converted to lists
57+
obj = [cls._to_dict(v) for v in obj]
58+
return obj
59+
60+
def keys(self):
61+
return self._dict.keys()
62+
63+
def items(self):
64+
self._keys_read.update(self._dict.keys())
65+
return self._dict.items()
66+
67+
def values(self):
68+
self._keys_read.update(self._dict.keys())
69+
return self._dict.values()
70+
71+
def pop(self, key, default=None):
72+
return self._dict.pop(key, default)
73+
74+
def popitem(self):
75+
return self._dict.popitem()
76+
77+
def clear(self):
78+
self._dict.clear()
79+
80+
def update(self, *args, **kwargs):
81+
if self._frozen:
82+
raise TypeError("Frozen ClassyConfigDicts do not support updates")
83+
self._dict.update(*args, **kwargs)
84+
85+
def setdefault(self, key, default=None):
86+
return self._dict.setdefault(key, default)
87+
88+
def __contains__(self, key):
89+
return key in self._dict
90+
91+
def __eq__(self, obj):
92+
return self._dict == obj
93+
94+
def __len__(self):
95+
return len(self._dict)
96+
97+
def __getitem__(self, key):
98+
self._keys_read.add(key)
99+
return self._dict.__getitem__(key)
100+
101+
def __iter__(self):
102+
return iter(self._dict)
103+
104+
def __str__(self):
105+
return json.dumps(self.to_dict(), indent=4)
106+
107+
def __repr__(self):
108+
return repr(self._dict)
109+
110+
def get(self, key, default=None):
111+
if key in self._dict.keys():
112+
self._keys_read.add(key)
113+
return self._dict.get(key, default)
114+
115+
def __copy__(self):
116+
ret = ClassyConfigDict()
117+
for key, value in self._dict.items():
118+
self._keys_read.add(key)
119+
ret._dict[key] = value
120+
121+
def copy(self):
122+
return self.__copy__()
123+
124+
def __deepcopy__(self, memo=None):
125+
# for deepcopies we mark all the keys and sub-keys as read
126+
ret = ClassyConfigDict()
127+
for key, value in self._dict.items():
128+
self._keys_read.add(key)
129+
ret._dict[key] = copy.deepcopy(value)
130+
return ret
131+
132+
def __setitem__(self, key, value):
133+
if self._frozen:
134+
raise TypeError("Frozen ClassyConfigDicts do not support assignment")
135+
if isinstance(value, dict) and not isinstance(value, ClassyConfigDict):
136+
value = ClassyConfigDict(value)
137+
self._dict.__setitem__(key, value)
138+
139+
def __delitem__(self, key):
140+
if self._frozen:
141+
raise TypeError("Frozen ClassyConfigDicts do not support key deletion")
142+
del self._dict[key]
143+
144+
def _freeze(self, obj):
145+
if isinstance(obj, Mapping):
146+
assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
147+
obj._frozen = True
148+
for value in obj.values():
149+
self._freeze(value)
150+
elif isinstance(obj, list):
151+
for value in obj:
152+
self._freeze(value)
153+
154+
def _reset_tracking(self, obj):
155+
if isinstance(obj, Mapping):
156+
assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
157+
obj._keys_read = set()
158+
for value in obj._dict.values():
159+
self._reset_tracking(value)
160+
elif isinstance(obj, list):
161+
for value in obj:
162+
self._reset_tracking(value)
163+
164+
def _unused_keys(self, obj):
165+
unused_keys = []
166+
if isinstance(obj, Mapping):
167+
assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
168+
unused_keys = [key for key in obj._dict.keys() if key not in obj._keys_read]
169+
for key, value in obj._dict.items():
170+
unused_keys += [
171+
f"{key}.{subkey}" for subkey in self._unused_keys(value)
172+
]
173+
elif isinstance(obj, list):
174+
for i, value in enumerate(obj):
175+
unused_keys += [f"{i}.{subkey}" for subkey in self._unused_keys(value)]
176+
return unused_keys
177+
178+
def freeze(self):
179+
"""Freeze the ClassyConfigDict to disallow mutations"""
180+
self._freeze(self)
181+
182+
def reset_tracking(self):
183+
"""Reset key tracking"""
184+
self._reset_tracking(self)
185+
186+
def unused_keys(self):
187+
"""Fetch all the unused keys"""
188+
return self._unused_keys(self)
189+
190+
def check_unused_keys(self):
191+
"""Raise if the config has unused keys"""
192+
unused_keys = self.unused_keys()
193+
if unused_keys:
194+
raise ConfigUnusedKeysError(unused_keys)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import List
7+
8+
9+
class ConfigError(Exception):
10+
pass
11+
12+
13+
class ConfigUnusedKeysError(ConfigError):
14+
def __init__(self, unused_keys: List[str]):
15+
self.unused_keys = unused_keys
16+
super().__init__(f"The following keys were unused: {self.unused_keys}")

classy_vision/optim/sgd.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88

99
import torch.optim
10+
from classy_vision.configuration import ClassyConfigDict
1011

1112
from . import ClassyOptimizer, register_optimizer
1213

@@ -63,10 +64,11 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
6364
config.setdefault("weight_decay", 0.0)
6465
config.setdefault("nesterov", False)
6566
config.setdefault("use_larc", False)
66-
config.setdefault(
67-
"larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}
68-
)
69-
67+
if config["use_larc"]:
68+
larc_config = ClassyConfigDict(clip=True, eps=1e-8, trust_coefficient=0.02)
69+
else:
70+
larc_config = None
71+
config.setdefault("larc_config", larc_config)
7072
assert (
7173
config["momentum"] >= 0.0
7274
and config["momentum"] < 1.0

classy_vision/tasks/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from pathlib import Path
88

9+
from classy_vision.configuration import ClassyConfigDict
910
from classy_vision.generic.registry_utils import import_all_modules
1011

1112
from .classy_task import ClassyTask
1213

13-
1414
FILE_ROOT = Path(__file__).parent
1515

1616

@@ -26,8 +26,13 @@ def build_task(config):
2626
"foo": "bar"}` will find a class that was registered as "my_task"
2727
(see :func:`register_task`) and call .from_config on it."""
2828

29+
config = ClassyConfigDict(config)
30+
2931
task = TASK_REGISTRY[config["name"]].from_config(config)
3032

33+
# at this stage all the configs keys should have been used
34+
config.check_unused_keys()
35+
3136
return task
3237

3338

classy_vision/tasks/classification_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
494494
Returns:
495495
A ClassificationTask instance.
496496
"""
497+
497498
test_only = config.get("test_only", False)
498499
if not test_only:
499500
# TODO Make distinction between epochs and phases in optimizer clear
@@ -1252,7 +1253,6 @@ def log_phase_end(self, tag):
12521253

12531254
def __repr__(self):
12541255
if hasattr(self, "_config"):
1255-
config = json.dumps(self._config, indent=4)
1256-
return f"{super().__repr__()} initialized with config:\n{config}"
1256+
return f"{super().__repr__()} initialized with config:\n{self._config}"
12571257

12581258
return super().__repr__()

0 commit comments

Comments
 (0)