Skip to content

Commit

Permalink
[Enhance] Enhance config (#1232)
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE authored Jul 5, 2023
1 parent 33e30b7 commit 8d4bac2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 44 deletions.
92 changes: 51 additions & 41 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@
import re # type: ignore


def _lazy2string(cfg_dict, dict_type=None):
if isinstance(cfg_dict, dict):
dict_type = dict_type or type(cfg_dict)
return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)})
elif isinstance(cfg_dict, (tuple, list)):
return type(cfg_dict)(_lazy2string(v) for v in cfg_dict)
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
return f'{cfg_dict.module}.{str(cfg_dict)}'
else:
return cfg_dict


class ConfigDict(Dict):
"""A dictionary for config which has the same interface as python's built-
in dictionary and can be used as a normal dictionary.
Expand Down Expand Up @@ -249,7 +261,7 @@ def _merge_a_into_b(a, b):
for key, value in merged.items():
self[key] = value

def to_dict(self):
def _to_lazy_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built."""

Expand All @@ -268,6 +280,11 @@ def _to_dict(data):

return _to_dict(self)

def to_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built."""
return _lazy2string(self, dict_type=dict)


def add_args(parser: ArgumentParser,
cfg: dict,
Expand Down Expand Up @@ -318,6 +335,8 @@ class Config:
cfg_text (str, optional): Text of config. Defaults to None.
filename (str or Path, optional): Name of config file.
Defaults to None.
format_python_code (bool): Whether to format Python code by yapf.
Defaults to True.
Here is a simple example:
Expand Down Expand Up @@ -348,7 +367,8 @@ def __init__(self,
cfg_dict: dict = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None):
env_variables: Optional[dict] = None,
format_python_code: bool = True):
filename = str(filename) if isinstance(filename, Path) else filename
if cfg_dict is None:
cfg_dict = dict()
Expand All @@ -363,6 +383,7 @@ def __init__(self,
cfg_dict = ConfigDict(cfg_dict)
super().__setattr__('_cfg_dict', cfg_dict)
super().__setattr__('_filename', filename)
super().__setattr__('_format_python_code', format_python_code)
if cfg_text:
text = cfg_text
elif filename:
Expand All @@ -380,7 +401,8 @@ def fromfile(filename: Union[str, Path],
use_predefined_variables: bool = True,
import_custom_modules: bool = True,
use_environment_variables: bool = True,
lazy_import: Optional[bool] = None) -> 'Config':
lazy_import: Optional[bool] = None,
format_python_code: bool = True) -> 'Config':
"""Build a Config instance from config file.
Args:
Expand All @@ -392,6 +414,8 @@ def fromfile(filename: Union[str, Path],
lazy_import (bool): Whether to load config in `lazy_import` mode.
If it is `None`, it will be deduced by the content of the
config file. Defaults to None.
format_python_code (bool): Whether to format Python code by yapf.
Defaults to True.
Returns:
Config: Config instance built from config file.
Expand Down Expand Up @@ -434,13 +458,18 @@ def fromfile(filename: Union[str, Path],
raise e
finally:
ConfigDict.lazy = False
for key, value in list(cfg_dict.to_dict().items()):

# delete builtin imported objects
for key, value in list(cfg_dict._to_lazy_dict().items()):
if isinstance(value, (types.FunctionType, types.ModuleType)):
cfg_dict.pop(key)

# disable lazy import to get the real type. See more details about
# lazy in the docstring of ConfigDict
cfg = Config(cfg_dict, filename=filename)
cfg = Config(
cfg_dict,
filename=filename,
format_python_code=format_python_code)
object.__setattr__(cfg, '_imported_names', imported_names)
return cfg

Expand Down Expand Up @@ -1321,8 +1350,6 @@ def _indent(s_, num_spaces):
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = repr(v)
elif isinstance(v, (LazyObject, LazyAttr)):
v_str = f"'{v.module}.{str(v)}'"
else:
v_str = str(v)

Expand Down Expand Up @@ -1354,8 +1381,6 @@ def _format_list_tuple(k, v, use_mapping=False):
v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501
elif isinstance(item, str):
v_str += f'{_indent(repr(item), indent)},\n'
elif isinstance(item, (LazyObject, LazyAttr)):
v_str += f"'{str(item)}',\n"
else:
v_str += str(item) + ',\n'
if k is None:
Expand Down Expand Up @@ -1385,9 +1410,7 @@ def _format_dict(input_dict, outest_level=False):
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = '' if outest_level or is_last else ','
if isinstance(v, (LazyObject, LazyAttr)):
attr_str = _format_basic_types(k, v, use_mapping) + end
elif isinstance(v, dict):
if isinstance(v, dict):
v_str = '\n' + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
Expand All @@ -1406,19 +1429,20 @@ def _format_dict(input_dict, outest_level=False):
r += '}'
return r

cfg_dict = self._to_lazy_dict()
cfg_dict = self.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
# copied from setup.cfg
yapf_style = dict(
based_on_style='pep8',
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True)
try:
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
except: # noqa: E722
raise SyntaxError('Failed to format the config file, please '
f'check the syntax of: \n{text}')

if self._format_python_code:
# copied from setup.cfg
yapf_style = dict(
based_on_style='pep8',
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True)
try:
text, _ = FormatCode(
text, style_config=yapf_style, verify=True)
except: # noqa: E722
raise SyntaxError('Failed to format the config file, please '
f'check the syntax of: \n{text}')
return text

def __repr__(self):
Expand Down Expand Up @@ -1490,7 +1514,7 @@ def dump(self, file: Optional[Union[str, Path]] = None):
str or None: Config text.
"""
file = str(file) if isinstance(file, Path) else file
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
cfg_dict = self.to_dict()
if file is None:
if self.filename is None or self.filename.endswith('.py'):
return self.pretty_text
Expand Down Expand Up @@ -1594,7 +1618,7 @@ def _is_lazy_import(filename: str) -> bool:
def _to_lazy_dict(self, keep_imported: bool = False) -> dict:
"""Convert config object to dictionary and filter the imported
object."""
res = self._cfg_dict.to_dict()
res = self._cfg_dict._to_lazy_dict()
if hasattr(self, '_imported_names') and not keep_imported:
res = {
key: value
Expand All @@ -1613,21 +1637,7 @@ def to_dict(self, keep_imported: bool = False):
If you import third-party objects in the config file, all imported
objects will be converted to a string like ``torch.optim.SGD``
"""
_cfg_dict = self._to_lazy_dict(keep_imported)

def lazy2string(cfg_dict):
if isinstance(cfg_dict, dict):
return type(cfg_dict)(
{k: lazy2string(v)
for k, v in cfg_dict.items()})
elif isinstance(cfg_dict, (tuple, list)):
return type(cfg_dict)(lazy2string(v) for v in cfg_dict)
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
return f'{cfg_dict.module}.{str(cfg_dict)}'
else:
return cfg_dict

return lazy2string(_cfg_dict)
return self._cfg_dict.to_dict()


class DictAction(Action):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def test_build_lazy(self):
]))
cfg_dict = ConfigDict(raw)
# check `items` and values
self.assertDictEqual(cfg_dict.to_dict(), raw)
self.assertDictEqual(cfg_dict._to_lazy_dict(), raw)
self._check(cfg_dict)

# check getattr
Expand Down Expand Up @@ -1132,10 +1132,10 @@ def test_build_lazy(self):
def _check(self, cfg_dict):
self._recursive_check_lazy(cfg_dict,
lambda x: not isinstance(x, LazyObject))
self._recursive_check_lazy(cfg_dict.to_dict(),
self._recursive_check_lazy(cfg_dict._to_lazy_dict(),
lambda x: x is not mmengine)
self._recursive_check_lazy(
cfg_dict.to_dict(), lambda x: not isinstance(x, ConfigDict)
cfg_dict._to_lazy_dict(), lambda x: not isinstance(x, ConfigDict)
if isinstance(x, dict) else True)
self._recursive_check_lazy(
cfg_dict, lambda x: isinstance(x, ConfigDict)
Expand Down

0 comments on commit 8d4bac2

Please sign in to comment.