diff --git a/main.py b/main.py index 7b4f94c5..4a191dc6 100644 --- a/main.py +++ b/main.py @@ -11,13 +11,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from pytorch_lightning.utilities.distributed import rank_zero_only -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - +from taming import get_obj_from_str, instantiate_from_config def get_parser(**parser_kwargs): def str2bool(v): diff --git a/taming/__init__.py b/taming/__init__.py new file mode 100644 index 00000000..ac572368 --- /dev/null +++ b/taming/__init__.py @@ -0,0 +1,13 @@ +import importlib + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) diff --git a/taming/data/__init__.py b/taming/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/models/__init__.py b/taming/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/models/cond_transformer.py b/taming/models/cond_transformer.py index 6e6869b0..f4bc583f 100644 --- a/taming/models/cond_transformer.py +++ b/taming/models/cond_transformer.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config +from taming import instantiate_from_config from taming.modules.util import SOSProvider diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py index 121d01fd..9d751d2d 100644 --- a/taming/models/vqgan.py +++ b/taming/models/vqgan.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config +from taming import instantiate_from_config from taming.modules.diffusionmodules.model import Encoder, Decoder from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer diff --git a/taming/modules/__init__.py b/taming/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/diffusionmodules/__init__.py b/taming/modules/diffusionmodules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/discriminator/__init__.py b/taming/modules/discriminator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/misc/__init__.py b/taming/modules/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/transformer/__init__.py b/taming/modules/transformer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/vqvae/__init__.py b/taming/modules/vqvae/__init__.py new file mode 100644 index 00000000..e69de29b