From 0d8a259b077e1330142ca35edbae3ad6c5c277dd Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 17 Sep 2023 19:56:00 +0800 Subject: [PATCH 1/3] Stable diffusion VAE fine tuning (backport AutoencoderKL and its config.yaml to taming-transformers) --- configs/finetune_vae.yaml | 43 +++++++ taming/models/vqgan.py | 238 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 281 insertions(+) create mode 100644 configs/finetune_vae.yaml diff --git a/configs/finetune_vae.yaml b/configs/finetune_vae.yaml new file mode 100644 index 00000000..da226c24 --- /dev/null +++ b/configs/finetune_vae.yaml @@ -0,0 +1,43 @@ +model: + base_learning_rate: 4.5e-6 + target: taming.models.vqgan.AutoencoderKL + params: + embed_dim: 4 + ckpt_path: "path/to/some/vae" + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1,2,4,4] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: False + disc_in_channels: 3 + disc_start: 10000 + disc_weight: 0.8 + #codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 1 + num_workers: 2 + train: + target: taming.data.custom.CustomTrain + params: + training_images_list_file: train_img.txt + size: 256 + validation: + target: taming.data.custom.CustomTest + params: + test_images_list_file: val_img.txt + size: 256 + diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py index a6950baa..1162271b 100644 --- a/taming/models/vqgan.py +++ b/taming/models/vqgan.py @@ -9,6 +9,244 @@ from taming.modules.vqvae.quantize import GumbelQuantize from taming.modules.vqvae.quantize import EMAVectorQuantizer + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + train_decoder_only=True, + #ema_decay=None, + #learn_logvar=False + ): + super().__init__() + #self.learn_logvar = learn_logvar + self.train_decoder_only = train_decoder_only + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) 
+ assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) # factor 2: mean and variance + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + #self.use_ema = ema_decay is not None + #if self.use_ema: + # self.ema_decay = ema_decay + # assert 0. < ema_decay < 1. + # self.model_ema = LitEma(self, decay=ema_decay) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + if path.endswith(".safetensors"): + from safetensors import safe_open + with safe_open(path, framework="pt", device=0) as f: + sd = {k: f.get_tensor(k) for k in f.keys()} + else: sd = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + #@contextmanager + #def ema_scope(self, context=None): + # if self.use_ema: + # self.model_ema.store(self.parameters()) + # self.model_ema.copy_to(self) + # if context is not None: + # print(f"{context}: Switched to EMA weights") + # try: + # yield None + # finally: + # if self.use_ema: + # self.model_ema.restore(self.parameters()) + # if context is not None: + # print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + #if self.use_ema: + # self.model_ema(self) + pass + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(torch.zeros_like(inputs), inputs, reconstructions, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(torch.zeros_like(inputs), inputs, reconstructions, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, 
batch_idx) + #with self.ema_scope(): + # log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(torch.zeros_like(inputs), inputs, reconstructions, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(torch.zeros_like(inputs), inputs, reconstructions, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters()) + if not self.train_decoder_only: + ae_params_list += list(self.encoder.parameters()) + list(self.quant_conv.parameters()) + #if self.learn_logvar: + # print(f"{self.__class__.__name__}: Learning logvar") + # ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + #if x.shape[1] > 3: + # # colorize with random projection + # assert xrec.shape[1] > 3 + # x = self.to_rgb(x) + # xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + #if log_ema or self.use_ema: + # with self.ema_scope(): + # xrec_ema, posterior_ema = self(x) + # if x.shape[1] > 3: + # # colorize with random projection + # assert xrec_ema.shape[1] > 3 + # xrec_ema = self.to_rgb(xrec_ema) + # log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + # log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + #def to_rgb(self, x): + # assert self.image_key == "segmentation" + # if not hasattr(self, "colorize"): + # self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + # x = F.conv2d(x, weight=self.colorize) + # x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ # return x + class VQModel(pl.LightningModule): def __init__(self, ddconfig, From efb20ebfd2fcc681da73aadb49f9132febad1c56 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Sep 2023 00:21:17 +0800 Subject: [PATCH 2/3] Colab notebook --- Finetune_VAE.ipynb | 176 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 Finetune_VAE.ipynb diff --git a/Finetune_VAE.ipynb b/Finetune_VAE.ipynb new file mode 100644 index 00000000..d75e937f --- /dev/null +++ b/Finetune_VAE.ipynb @@ -0,0 +1,176 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "**Install conda, taming-transformers and its dependencies**" + ], + "metadata": { + "id": "EZx3a2Dsc_wq" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -q condacolab\n", + "import condacolab\n", + "condacolab.install()\n", + "\n", + "!git clone https://github.com/rbbb/taming-transformers\n", + "%cd /content/taming-transformers\n", + "!git checkout stable-diff-vae-finetuning\n", + "\n", + "!conda env update -n taming-transformers -f environment.yaml\n", + "!conda run -n taming-transformers pip install -e ." + ], + "metadata": { + "id": "UYDCEq4iE_iF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Put some images in /content/src_img_here/**" + ], + "metadata": { + "id": "uyDtJvL-dOWr" + } + }, + { + "cell_type": "code", + "source": [ + "!mkdir /content/src_img_here/\n", + "#directory /0000/ from danbooru 2020 data containing 3.5k images\n", + "!unzip /content/drive/MyDrive/danbooru_tiny.zip -d /content/src_img_here/" + ], + "metadata": { + "id": "Zme4fnUSWwyo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Find all images, put their name in train_img.txt and val_img.txt (10 validation images)**" + ], + "metadata": { + "id": "S55xG8OEdYjI" + } + }, + { + "cell_type": "code", + "source": [ + "import glob\n", + "img_filenames = list(glob.glob(\"/content/src_img_here/**/*.*\", recursive=True))\n", + "img_filenames = [i for i in img_filenames if (i.endswith(\".png\") or i.endswith(\".jpg\") or i.endswith(\".jpeg\"))]\n", + "with open(\"/content/taming-transformers/train_img.txt\",\"w\") as f:\n", + " f.write(\"\\n\".join(img_filenames[10:]))\n", + "with open(\"/content/taming-transformers/val_img.txt\",\"w\") as f:\n", + " f.write(\"\\n\".join(img_filenames[0:10]))\n" + ], + "metadata": { + "id": "XvInztU0Lx6y" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Get a VAE from civitai, convert from safetensors to ckpt**" + ], + "metadata": { + "id": "z71bh_0vdoap" + } + }, + { + "cell_type": "code", + "source": [ + "!mkdir /content/src_vae_here/\n", + "!wget -O /content/src_vae_here/vae.safetensors https://civitai.com/api/download/models/88156 #Clear VAE\n", + "\n", + "!pip install safetensors torch\n", + "from safetensors import safe_open\n", + "import torch\n", + "vae = {}\n", + "with safe_open(\"/content/src_vae_here/vae.safetensors\", framework=\"pt\", device=0) as f:\n", + " vae = {k: f.get_tensor(k) for k in f.keys()}\n", + "torch.save({'state_dict':vae}, \"/content/src_vae_here/vae.pt\")" + ], + "metadata": { + "id": "kPgbHeCpW-Et" + }, + "execution_count": null, + 
"outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Change the parameters in the /content/taming-transformers/configs/finetune_vae.yaml (name of VAE file and batch size)**" + ], + "metadata": { + "id": "jg7BpQ4ad5xG" + } + }, + { + "cell_type": "code", + "source": [ + "import yaml\n", + "with open('/content/taming-transformers/configs/finetune_vae.yaml','r') as file:\n", + " conf = yaml.load(file, Loader=yaml.FullLoader)\n", + "\n", + "conf['model']['params']['ckpt_path'] = \"\\\"/content/src_vae_here/vae.pt\\\"\"\n", + "conf['data']['params']['batch_size'] = \"5\"\n", + "\n", + "with open('/content/taming-transformers/configs/finetune_vae.yaml','w') as file:\n", + " yaml.dump(conf, file)\n", + "\n" + ], + "metadata": { + "id": "0AX9Q_6zNu1C" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Run**" + ], + "metadata": { + "id": "-GXysMjAeLuC" + } + }, + { + "cell_type": "code", + "source": [ + "%cd /content/taming-transformers\n", + "!conda run --no-capture-output -n taming-transformers python main.py --base configs/finetune_vae.yaml -t True --gpus 0, --accumulate_grad_batches 6" + ], + "metadata": { + "id": "GSOiQtWKYzeV" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From 3484597a005b692f00b022ab6f3d071feba1b79b Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Sep 2023 00:28:29 +0800 Subject: [PATCH 3/3] Colab notebook --- Finetune_VAE.ipynb | 176 --------------------------------------------- 1 file changed, 176 deletions(-) delete mode 100644 Finetune_VAE.ipynb diff --git a/Finetune_VAE.ipynb b/Finetune_VAE.ipynb deleted file mode 100644 index d75e937f..00000000 --- a/Finetune_VAE.ipynb +++ /dev/null @@ -1,176 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "gpuType": "T4" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "**Install conda, taming-transformers and its dependencies**" - ], - "metadata": { - "id": "EZx3a2Dsc_wq" - } - }, - { - "cell_type": "code", - "source": [ - "!pip install -q condacolab\n", - "import condacolab\n", - "condacolab.install()\n", - "\n", - "!git clone https://github.com/rbbb/taming-transformers\n", - "%cd /content/taming-transformers\n", - "!git checkout stable-diff-vae-finetuning\n", - "\n", - "!conda env update -n taming-transformers -f environment.yaml\n", - "!conda run -n taming-transformers pip install -e ." 
- ], - "metadata": { - "id": "UYDCEq4iE_iF" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "**Put some images in /content/src_img_here/**" - ], - "metadata": { - "id": "uyDtJvL-dOWr" - } - }, - { - "cell_type": "code", - "source": [ - "!mkdir /content/src_img_here/\n", - "#directory /0000/ from danbooru 2020 data containing 3.5k images\n", - "!unzip /content/drive/MyDrive/danbooru_tiny.zip -d /content/src_img_here/" - ], - "metadata": { - "id": "Zme4fnUSWwyo" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "**Find all images, put their name in train_img.txt and val_img.txt (10 validation images)**" - ], - "metadata": { - "id": "S55xG8OEdYjI" - } - }, - { - "cell_type": "code", - "source": [ - "import glob\n", - "img_filenames = list(glob.glob(\"/content/src_img_here/**/*.*\", recursive=True))\n", - "img_filenames = [i for i in img_filenames if (i.endswith(\".png\") or i.endswith(\".jpg\") or i.endswith(\".jpeg\"))]\n", - "with open(\"/content/taming-transformers/train_img.txt\",\"w\") as f:\n", - " f.write(\"\\n\".join(img_filenames[10:]))\n", - "with open(\"/content/taming-transformers/val_img.txt\",\"w\") as f:\n", - " f.write(\"\\n\".join(img_filenames[0:10]))\n" - ], - "metadata": { - "id": "XvInztU0Lx6y" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "**Get a VAE from civitai, convert from safetensors to ckpt**" - ], - "metadata": { - "id": "z71bh_0vdoap" - } - }, - { - "cell_type": "code", - "source": [ - "!mkdir /content/src_vae_here/\n", - "!wget -O /content/src_vae_here/vae.safetensors https://civitai.com/api/download/models/88156 #Clear VAE\n", - "\n", - "!pip install safetensors torch\n", - "from safetensors import safe_open\n", - "import torch\n", - "vae = {}\n", - "with safe_open(\"/content/src_vae_here/vae.safetensors\", framework=\"pt\", device=0) as f:\n", - " vae = {k: f.get_tensor(k) for k in f.keys()}\n", - "torch.save({'state_dict':vae}, \"/content/src_vae_here/vae.pt\")" - ], - "metadata": { - "id": "kPgbHeCpW-Et" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "**Change the parameters in the /content/taming-transformers/configs/finetune_vae.yaml (name of VAE file and batch size)**" - ], - "metadata": { - "id": "jg7BpQ4ad5xG" - } - }, - { - "cell_type": "code", - "source": [ - "import yaml\n", - "with open('/content/taming-transformers/configs/finetune_vae.yaml','r') as file:\n", - " conf = yaml.load(file, Loader=yaml.FullLoader)\n", - "\n", - "conf['model']['params']['ckpt_path'] = \"\\\"/content/src_vae_here/vae.pt\\\"\"\n", - "conf['data']['params']['batch_size'] = \"5\"\n", - "\n", - "with open('/content/taming-transformers/configs/finetune_vae.yaml','w') as file:\n", - " yaml.dump(conf, file)\n", - "\n" - ], - "metadata": { - "id": "0AX9Q_6zNu1C" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "**Run**" - ], - "metadata": { - "id": "-GXysMjAeLuC" - } - }, - { - "cell_type": "code", - "source": [ - "%cd /content/taming-transformers\n", - "!conda run --no-capture-output -n taming-transformers python main.py --base configs/finetune_vae.yaml -t True --gpus 0, --accumulate_grad_batches 6" - ], - "metadata": { - "id": "GSOiQtWKYzeV" - }, - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file
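
A quick sanity check for the AutoencoderKL backported in PATCH 1 is an encode/decode round trip on a dummy batch. The following is only a minimal sketch, assuming the patches above are applied to a taming-transformers checkout and the script is run from the repository root (so that the `from main import instantiate_from_config` inside taming/models/vqgan.py resolves); `ckpt_path` is left as None here, but could instead point at a vae.pt converted from safetensors as in the PATCH 2 notebook.

import torch
from omegaconf import OmegaConf
from taming.models.vqgan import AutoencoderKL

# Read the same hyperparameters the training run would use.
cfg = OmegaConf.load("configs/finetune_vae.yaml").model.params

model = AutoencoderKL(ddconfig=cfg.ddconfig,
                      lossconfig=cfg.lossconfig,
                      embed_dim=cfg.embed_dim,
                      ckpt_path=None)  # or a converted vae.pt, as in the PATCH 2 notebook
model.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 256, 256) * 2.0 - 1.0   # images are expected in [-1, 1]
    posterior = model.encode(x)                   # DiagonalGaussianDistribution
    z = posterior.mode()                          # deterministic latent (the mean)
    xrec = model.decode(z)

print(z.shape, xrec.shape)  # (1, 4, 32, 32) and (1, 3, 256, 256) with this ddconfig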
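
After fine-tuning, main.py leaves a Lightning checkpoint under logs/<run>/checkpoints/; to reuse the result as a plain VAE checkpoint, the LPIPS and discriminator weights stored under the `loss.` prefix have to be stripped. The sketch below is an assumption-laden export helper, not part of the patches: the checkpoint path is a placeholder, and the key filtering simply mirrors how AutoencoderKL names its submodules (encoder, decoder, quant_conv, post_quant_conv, loss).

import torch
from safetensors.torch import save_file

# Placeholder path: substitute the checkpoint Lightning actually wrote.
ckpt = torch.load("logs/<run>/checkpoints/last.ckpt", map_location="cpu")
sd = ckpt["state_dict"]

# Keep only encoder/decoder/quant_conv/post_quant_conv tensors; drop loss.* (LPIPS + discriminator).
vae_sd = {k: v for k, v in sd.items() if not k.startswith("loss.")}

torch.save({"state_dict": vae_sd}, "vae_finetuned.pt")
save_file({k: v.contiguous() for k, v in vae_sd.items()}, "vae_finetuned.safetensors")

This is roughly the inverse of the safetensors-to-.pt conversion done in the PATCH 2 notebook, so the exported file should load back through init_from_ckpt; whether it drops straight into other Stable Diffusion tooling depends on that tooling's expected key layout.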