From 1c129e74611dfea0930ba6f499ab1f328c9e5c96 Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Mon, 9 Oct 2023 12:43:15 +0800 Subject: [PATCH 1/7] idk if this works, but here's a small clip implementation --- muse_maskgit_pytorch/dataset.py | 24 ++++++++--------- muse_maskgit_pytorch/muse_maskgit_pytorch.py | 27 ++++++++++++------- .../trainers/maskgit_trainer.py | 27 ++++++++++++++----- setup.py | 1 + train_muse_maskgit.py | 26 ++++++++++++++++++ 5 files changed, 77 insertions(+), 28 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 34c95a0..753b351 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -180,14 +180,14 @@ def __getitem__(self, index): if self.using_taming: if self.embeds: - return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed + return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text else: if self.embeds: - return self.transform(image), input_ids[0], attn_mask[0], embed + return self.transform(image), input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text class URLTextDataset(ImageDataset): @@ -255,14 +255,14 @@ def __getitem__(self, index): if self.using_taming: if self.embeds: - return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed + return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text else: if self.embeds: - return self.transform(image), input_ids[0], attn_mask[0], embed + return self.transform(image), input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text class LocalTextImageDataset(Dataset): @@ -350,14 +350,14 @@ def __getitem__(self, index): attn_mask = encoded.attention_mask if self.using_taming: if self.embeds: - return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed + return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text else: if self.embeds: - return self.transform(image), input_ids[0], attn_mask[0], embed + return self.transform(image), input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text def get_directory_size(path): diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index e7615fd..2b3fd7b 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -169,6 +169,7 @@ def __init__( self_cond: bool = False, add_mask_id: bool = False, cache_path: PathLike = None, + use_clip=False, **kwargs, ): super().__init__() @@ -183,29 +184,35 @@ def __init__( self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs) self.norm = LayerNorm(dim) + self.use_clip = use_clip + self.dim_out = default(dim_out, num_tokens) self.to_logits = nn.Linear(dim, self.dim_out, bias=False) 
# text conditioning - t5, tokenizer = get_model_and_tokenizer(t5_name, cache_path) - self.t5: T5EncoderModel = t5 - self.tokenizer: T5Tokenizer = tokenizer + if not use_clip: + t5, tokenizer = get_model_and_tokenizer(t5_name, cache_path) + self.t5: T5EncoderModel = t5 + self.tokenizer: T5Tokenizer = tokenizer - self.t5.eval() + self.t5.eval() - text_embed_dim = get_encoded_dim(t5_name) + text_embed_dim = get_encoded_dim(t5_name) - self.text_embed_proj = ( - nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity() - ) + self.text_embed_proj = ( + nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity() + ) # optional self conditioning self.self_cond = self_cond self.self_cond_to_init_embed = FeedForward(dim) def encode_text(self, *args, **kwargs): - kwargs.update(tokenizer=self.tokenizer, t5=self.t5) - return t5_encode_text(*args, **kwargs) + if not self.use_clip: + kwargs.update(tokenizer=self.tokenizer, t5=self.t5) + return t5_encode_text(*args, **kwargs) + else: + print("Using clip instead, this function shouldn't be accessed") def forward_with_cond_scale(self, *args, cond_scale=3.0, return_embed=False, **kwargs): if cond_scale == 1: diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index cfce2f1..97b78e6 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -6,7 +6,6 @@ from diffusers.optimization import SchedulerType from ema_pytorch import EMA from omegaconf import OmegaConf -from PIL import Image from torch.optim import Optimizer from torch.utils.data import DataLoader from torchvision.utils import save_image @@ -24,6 +23,9 @@ xm = None met = None +import open_clip +import torchvision.transforms as transforms +from PIL import Image from tqdm import tqdm @@ -59,6 +61,7 @@ def __init__( validation_image_scale: float = 1.0, only_save_last_checkpoint=False, args=None, + clip=None, ): super().__init__( dataloader=dataloader, @@ -96,6 +99,9 @@ def __init__( self.optim: Optimizer = optimizer self.lr_scheduler: SchedulerType = scheduler + self.use_clip = True if clip is not None else False + self.clip_model = clip + self.use_ema = use_ema self.validation_prompts: List[str] = validation_prompts if use_ema: @@ -154,15 +160,24 @@ def train(self): # logs for epoch in range(self.current_step // len(self.dl), self.num_epochs): - for imgs, input_ids, attn_mask, text_embeds in iter(self.dl): + for imgs, input_ids, attn_mask, text_embeds, text in iter(self.dl): train_loss = 0.0 steps = int(self.steps.item()) - if not text_embeds: + if not self.use_clip: + if not text_embeds: + with torch.no_grad(): + text_embeds = t5_encode_text_from_encoded( + input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device + ) + else: + img_for_embed = transforms.ToPILImage(imgs) + + model, _, preprocess = self.clip_model + text = open_clip.tokenize(text) + with torch.no_grad(): - text_embeds = t5_encode_text_from_encoded( - input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device - ) + text_embeds = model.encode_text(text) with self.accelerator.accumulate(self.model), self.accelerator.autocast(): loss = self.model(imgs, text_embeds=text_embeds) diff --git a/setup.py b/setup.py index 7619c11..21839c9 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ "xformers>=0.0.20", "wandb", "bz2file", + "open_clip_torch", ], classifiers=[ "Development Status :: 4 - Beta", diff --git a/train_muse_maskgit.py 
b/train_muse_maskgit.py index 4d11423..9a7b3ab 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -9,6 +9,7 @@ import bz2file as bz2 import datasets import diffusers +import open_clip import torch import transformers from accelerate.utils import ProjectConfiguration @@ -124,6 +125,12 @@ def decompress_pickle(file): parser.add_argument("--heads", type=int, default=8, help="Attention heads") parser.add_argument("--ff_mult", type=int, default=4, help="Feed forward expansion factor") parser.add_argument("--t5_name", type=str, default="t5-small", help="Name of your t5 model") +parser.add_argument( + "--use_metaclip", + action="store_true", + default=False, + help="whether to use MetaClip instead of a T5", +) parser.add_argument("--cond_image_size", type=int, default=None, help="Conditional image size.") parser.add_argument( "--validation_prompt", @@ -480,6 +487,7 @@ class Arguments: heads: int = 8 ff_mult: int = 4 t5_name: str = "t5-small" + use_metaclip: bool = False mixed_precision: str = "no" cond_image_size: Optional[int] = None validation_prompt: str = "A photo of a dog" @@ -750,6 +758,7 @@ def main(): cache_path=args.cache_path, flash=flash, xformers=xformers, + use_clip=args.use_metaclip, ) # (2) pass your trained VAE and the base transformer to MaskGit @@ -987,6 +996,22 @@ def main(): args.batch_size, ) + if args.use_metaclip: + if args.mixed_precision == "no": + clip_precision = "fp32" + else: + clip_precision = args.mixed_precision + + clip = open_clip.create_model_and_transforms( + "ViT-B-32-quickgelu", + pretrained="metaclip/b32_400m.pt", + cache_dir=args.cache_path, + precision=clip_precision, + device=accelerator.device, + ) + else: + clip = None + # Create the trainer accelerator.wait_for_everyone() trainer = MaskGitTrainer( @@ -1017,6 +1042,7 @@ def main(): only_save_last_checkpoint=args.only_save_last_checkpoint, num_epochs=args.num_epochs, args=args, + clip=clip, ) # Prepare the trainer for distributed training From 15747635d74265658310872c0c558ab6d9572e6d Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Mon, 9 Oct 2023 13:15:44 +0800 Subject: [PATCH 2/7] forgot to normalise --- muse_maskgit_pytorch/trainers/maskgit_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 97b78e6..7d252eb 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -171,13 +171,12 @@ def train(self): input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device ) else: - img_for_embed = transforms.ToPILImage(imgs) - model, _, preprocess = self.clip_model text = open_clip.tokenize(text) with torch.no_grad(): text_embeds = model.encode_text(text) + text_embeds /= text_embeds.norm(dim=-1, keepdim=True) with self.accelerator.accumulate(self.model), self.accelerator.autocast(): loss = self.model(imgs, text_embeds=text_embeds) From 883c3ed79484820e3f693c104b35dcd0bdf75c57 Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Mon, 9 Oct 2023 13:34:50 +0800 Subject: [PATCH 3/7] swap to a better clip model --- train_muse_maskgit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 9a7b3ab..b8efe29 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -1003,8 +1003,8 @@ def main(): clip_precision = args.mixed_precision clip = 
open_clip.create_model_and_transforms( - "ViT-B-32-quickgelu", - pretrained="metaclip/b32_400m.pt", + "ViT-L-14", + pretrained="metaclip/l14_400m.pt", cache_dir=args.cache_path, precision=clip_precision, device=accelerator.device, From de06cf0068cd2f4e0c965a8d7e3053e9e7622aca Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Mon, 9 Oct 2023 16:28:30 +0800 Subject: [PATCH 4/7] maybe fix it? --- muse_maskgit_pytorch/dataset.py | 67 ++++++++++++-------- muse_maskgit_pytorch/muse_maskgit_pytorch.py | 1 + train_muse_maskgit.py | 7 +- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 753b351..431583f 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -167,16 +167,20 @@ def __getitem__(self, index): else: text = descriptions # max length from the paper - encoded = self.tokenizer.batch_encode_plus( - [str(text)], - return_tensors="pt", - padding="max_length", - max_length=MAX_LENGTH, - truncation=True, - ) + if self.tokenizer is not None: + encoded = self.tokenizer.batch_encode_plus( + [str(text)], + return_tensors="pt", + padding="max_length", + max_length=MAX_LENGTH, + truncation=True, + ) - input_ids = encoded.input_ids - attn_mask = encoded.attention_mask + input_ids = encoded.input_ids + attn_mask = encoded.attention_mask + else: + input_ids = [] + attn_mask = [] if self.using_taming: if self.embeds: @@ -242,16 +246,20 @@ def __getitem__(self, index): else: text = descriptions # max length from the paper - encoded = self.tokenizer.batch_encode_plus( - [str(text)], - return_tensors="pt", - padding="max_length", - max_length=MAX_LENGTH, - truncation=True, - ) + if self.tokenizer is not None: + encoded = self.tokenizer.batch_encode_plus( + [str(text)], + return_tensors="pt", + padding="max_length", + max_length=MAX_LENGTH, + truncation=True, + ) - input_ids = encoded.input_ids - attn_mask = encoded.attention_mask + input_ids = encoded.input_ids + attn_mask = encoded.attention_mask + else: + input_ids = [] + attn_mask = [] if self.using_taming: if self.embeds: @@ -338,16 +346,21 @@ def __getitem__(self, index): embed = self.embeds[index] # max length from the paper - encoded = self.tokenizer.batch_encode_plus( - [str(text)], - return_tensors="pt", - padding="max_length", - max_length=MAX_LENGTH, - truncation=True, - ) + if self.tokenizer is not None: + encoded = self.tokenizer.batch_encode_plus( + [str(text)], + return_tensors="pt", + padding="max_length", + max_length=MAX_LENGTH, + truncation=True, + ) + + input_ids = encoded.input_ids + attn_mask = encoded.attention_mask + else: + input_ids = [] + attn_mask = [] - input_ids = encoded.input_ids - attn_mask = encoded.attention_mask if self.using_taming: if self.embeds: return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 2b3fd7b..0199b2c 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -185,6 +185,7 @@ def __init__( self.norm = LayerNorm(dim) self.use_clip = use_clip + self.tokenizer = None self.dim_out = default(dim_out, num_tokens) self.to_logits = nn.Linear(dim, self.dim_out, bias=False) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index b8efe29..c573a9d 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -824,6 +824,9 @@ def main(): else: embeds = [] + if 
args.use_metaclip: + transformer.tokenizer = None + # Create the dataset objects with accelerator.main_process_first(): if args.no_cache and args.train_data_dir: @@ -1003,8 +1006,8 @@ def main(): clip_precision = args.mixed_precision clip = open_clip.create_model_and_transforms( - "ViT-L-14", - pretrained="metaclip/l14_400m.pt", + "convnext_base_w", + pretrained="laion2b_s13b_b82k_augreg", cache_dir=args.cache_path, precision=clip_precision, device=accelerator.device, From 0811ef26de1afd13feeb971ce288233aa3e7c5fe Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:35:57 +0800 Subject: [PATCH 5/7] DONT MERGE swapping to HF clip --- muse_maskgit_pytorch/dataset.py | 12 ++++++------ muse_maskgit_pytorch/muse_maskgit_pytorch.py | 9 ++++++--- muse_maskgit_pytorch/trainers/maskgit_trainer.py | 11 ++++++++--- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 431583f..d56aef0 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -179,8 +179,8 @@ def __getitem__(self, index): input_ids = encoded.input_ids attn_mask = encoded.attention_mask else: - input_ids = [] - attn_mask = [] + input_ids = [0] + attn_mask = [0] if self.using_taming: if self.embeds: @@ -258,8 +258,8 @@ def __getitem__(self, index): input_ids = encoded.input_ids attn_mask = encoded.attention_mask else: - input_ids = [] - attn_mask = [] + input_ids = [0] + attn_mask = [0] if self.using_taming: if self.embeds: @@ -358,8 +358,8 @@ def __getitem__(self, index): input_ids = encoded.input_ids attn_mask = encoded.attention_mask else: - input_ids = [] - attn_mask = [] + input_ids = [0] + attn_mask = [0] if self.using_taming: if self.embeds: diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 0199b2c..b8fdf95 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -200,9 +200,12 @@ def __init__( text_embed_dim = get_encoded_dim(t5_name) - self.text_embed_proj = ( - nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity() - ) + else: + text_embed_dim = 640 + + self.text_embed_proj = ( + nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity() + ) # optional self conditioning self.self_cond = self_cond diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 7d252eb..59e5d81 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -93,7 +93,7 @@ def __init__( # maskgit maskgit.vae.requires_grad_(False) - maskgit.transformer.t5.requires_grad_(False) + self.model: MaskGit = maskgit self.optim: Optimizer = optimizer @@ -102,6 +102,9 @@ def __init__( self.use_clip = True if clip is not None else False self.clip_model = clip + if not self.use_clip: + maskgit.transformer.t5.requires_grad_(False) + self.use_ema = use_ema self.validation_prompts: List[str] = validation_prompts if use_ema: @@ -175,8 +178,10 @@ def train(self): text = open_clip.tokenize(text) with torch.no_grad(): - text_embeds = model.encode_text(text) - text_embeds /= text_embeds.norm(dim=-1, keepdim=True) + text_embeds = model.encode_text(text.to(self.accelerator.device)) + text_embeds = text_embeds.unsqueeze(2).to(self.accelerator.device) + print(text_embeds.shape) + print(imgs.shape) with 
self.accelerator.accumulate(self.model), self.accelerator.autocast(): loss = self.model(imgs, text_embeds=text_embeds) From 2070a6963daf4267fcd3ae2938b01352b27dce16 Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Tue, 10 Oct 2023 17:03:39 +0800 Subject: [PATCH 6/7] FINALLY (CLIP WORKS) this took way too long, on the plus side, we can also do clip with more than 77 tokens --- muse_maskgit_pytorch/muse_maskgit_pytorch.py | 2 +- .../trainers/maskgit_trainer.py | 54 ++++++++++++++++--- setup.py | 1 - train_muse_maskgit.py | 28 ++++------ 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index b8fdf95..0d44594 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -201,7 +201,7 @@ def __init__( text_embed_dim = get_encoded_dim(t5_name) else: - text_embed_dim = 640 + text_embed_dim = 512 self.text_embed_proj = ( nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity() diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 59e5d81..0389952 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -29,6 +29,21 @@ from tqdm import tqdm +def divide_string(string, parts): + # Determine the length of each substring + part_length = len(string) // parts + + # Divide the string into 'parts' number of substrings + substrings = [string[i : i + part_length] for i in range(0, len(string), part_length)] + + # If there are any leftover characters, add them to the last substring + if len(substrings) > parts: + substrings[-2] += substrings[-1] + substrings.pop() + + return substrings + + class MaskGitTrainer(BaseAcceleratedTrainer): def __init__( self, @@ -174,14 +189,37 @@ def train(self): input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device ) else: - model, _, preprocess = self.clip_model - text = open_clip.tokenize(text) - - with torch.no_grad(): - text_embeds = model.encode_text(text.to(self.accelerator.device)) - text_embeds = text_embeds.unsqueeze(2).to(self.accelerator.device) - print(text_embeds.shape) - print(imgs.shape) + clip_model, clip_tokenizer = self.clip_model + inputs = [token[1:-1] for token in clip_tokenizer(text, truncation=True).input_ids] + + inputs = torch.tensor(inputs, device=self.accelerator.device) + + max_embeddings_multiples = (inputs.shape[1] - 2) // (75 - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = inputs[:, i * (75 - 2) : (i + 1) * (75 - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = inputs[0, 0] + text_input_chunk[:, -1] = inputs[0, -1] + text_embedding = clip_model(text_input_chunk)[0] + + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeds = torch.concat(text_embeddings, axis=1).to(self.accelerator.device) + else: + text_embeds = clip_model(inputs)[0].to(self.accelerator.device) with self.accelerator.accumulate(self.model), 
self.accelerator.autocast(): loss = self.model(imgs, text_embeds=text_embeds) diff --git a/setup.py b/setup.py index 21839c9..7619c11 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,6 @@ "xformers>=0.0.20", "wandb", "bz2file", - "open_clip_torch", ], classifiers=[ "Development Status :: 4 - Beta", diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index c573a9d..337b258 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -9,7 +9,6 @@ import bz2file as bz2 import datasets import diffusers -import open_clip import torch import transformers from accelerate.utils import ProjectConfiguration @@ -19,6 +18,7 @@ from rich import inspect from torch.optim import Optimizer from tqdm import tqdm +from transformers import AutoTokenizer, CLIPTextModel import wandb from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded @@ -126,7 +126,7 @@ def decompress_pickle(file): parser.add_argument("--ff_mult", type=int, default=4, help="Feed forward expansion factor") parser.add_argument("--t5_name", type=str, default="t5-small", help="Name of your t5 model") parser.add_argument( - "--use_metaclip", + "--use_clip", action="store_true", default=False, help="whether to use MetaClip instead of a T5", @@ -487,7 +487,7 @@ class Arguments: heads: int = 8 ff_mult: int = 4 t5_name: str = "t5-small" - use_metaclip: bool = False + use_clip: bool = False mixed_precision: str = "no" cond_image_size: Optional[int] = None validation_prompt: str = "A photo of a dog" @@ -758,7 +758,7 @@ def main(): cache_path=args.cache_path, flash=flash, xformers=xformers, - use_clip=args.use_metaclip, + use_clip=args.use_clip, ) # (2) pass your trained VAE and the base transformer to MaskGit @@ -824,7 +824,7 @@ def main(): else: embeds = [] - if args.use_metaclip: + if args.use_clip: transformer.tokenizer = None # Create the dataset objects @@ -999,19 +999,13 @@ def main(): args.batch_size, ) - if args.use_metaclip: - if args.mixed_precision == "no": - clip_precision = "fp32" - else: - clip_precision = args.mixed_precision - - clip = open_clip.create_model_and_transforms( - "convnext_base_w", - pretrained="laion2b_s13b_b82k_augreg", - cache_dir=args.cache_path, - precision=clip_precision, - device=accelerator.device, + if args.use_clip: + model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path).to( + accelerator.device ) + tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path) + + clip = (model, tokenizer) else: clip = None From 701fd851bb5a2df769e8d3ec400b32777f70c081 Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Tue, 10 Oct 2023 17:45:53 +0800 Subject: [PATCH 7/7] someone a lot smarter than me will have to figure out negative prompting also now doing image generation sequentially because adding clip to inference is harder than adding it to training --- muse_maskgit_pytorch/muse_maskgit_pytorch.py | 91 ++++++++++++++----- .../trainers/maskgit_trainer.py | 20 ++-- train_muse_maskgit.py | 24 +++-- 3 files changed, 92 insertions(+), 43 deletions(-) diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 0d44594..6566370 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -13,7 +13,7 @@ from einops import rearrange, repeat from torch import einsum, isnan, nn from tqdm.auto import tqdm -from transformers import T5EncoderModel, T5Tokenizer +from transformers import 
CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer from .attn import ein_attn, sdp_attn from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text @@ -417,6 +417,8 @@ def __init__( self, image_size, transformer: MaskGitTransformer, + clip: CLIPTextModel = None, + clip_tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast] = None, accelerator: Optional[Accelerator] = None, noise_schedule: Callable = cosine_schedule, token_critic: Optional[TokenCritic] = None, @@ -446,6 +448,9 @@ def __init__( self.resize_image_for_cond_image = exists(cond_image_size) self.cond_drop_prob = cond_drop_prob + self.clip = clip + self.clip_tokenizer = clip_tokenizer + self.transformer = transformer self.self_cond = transformer.self_cond if not self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens: @@ -524,22 +529,54 @@ def generate( cond_ids = None - text_embeds = self.transformer.encode_text(texts) - - demask_fn = self.transformer.forward_with_cond_scale - - # whether to use token critic for scores - use_token_critic = exists(self.token_critic) and not force_not_use_token_critic - + # whether to use token critic for scores if use_token_critic: token_critic_fn = self.token_critic.forward_with_cond_scale - # negative prompting, as in paper + if self.clip is not None and self.clip_tokenizer is not None: + clip_model = self.clip + clip_tokenizer = self.clip_tokenizer + print(texts) + inputs = [token[1:-1] for token in clip_tokenizer(texts, truncation=True).input_ids] + + inputs = torch.tensor(inputs, device=self.accelerator.device) + max_embeddings_multiples = (inputs.shape[1] - 2) // (75 - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = inputs[:, i * (75 - 2) : (i + 1) * (75 - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = inputs[0, 0] + text_input_chunk[:, -1] = inputs[0, -1] + text_embedding = clip_model(text_input_chunk)[0] + + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeds = torch.concat(text_embeddings, axis=1).to(self.accelerator.device) + else: + text_embeds = clip_model(inputs)[0].to(self.accelerator.device) + else: + text_embeds = self.transformer.encode_text(texts) - neg_text_embeds = None - if exists(negative_texts): - assert len(texts) == len(negative_texts) + demask_fn = self.transformer.forward_with_cond_scale + + # negative prompting, as in paper + + neg_text_embeds = None + if exists(negative_texts): + assert len(texts) == len(negative_texts) neg_text_embeds = self.transformer.encode_text(negative_texts) demask_fn = partial( @@ -547,11 +584,11 @@ def generate( neg_text_embeds=neg_text_embeds, ) - if use_token_critic: - token_critic_fn = partial( - self.token_critic.forward_with_neg_prompt, - neg_text_embeds=neg_text_embeds, - ) + if use_token_critic: + token_critic_fn = partial( + self.token_critic.forward_with_neg_prompt, + neg_text_embeds=neg_text_embeds, + ) if self.resize_image_for_cond_image: if cond_images is None: @@ -576,14 +613,18 @@ def generate( ids = ids.scatter(1, masked_indices, self.mask_id) - logits, embed = demask_fn( - ids, - 
text_embeds=text_embeds, - self_cond_embed=self_cond_embed, - conditioning_token_ids=cond_ids, - cond_scale=cond_scale, - return_embed=True, - ) + if self.clip is None: + logits, embed = demask_fn( + ids, + text_embeds=text_embeds, + self_cond_embed=self_cond_embed, + conditioning_token_ids=cond_ids, + cond_scale=cond_scale, + return_embed=True, + ) + else: + embed = text_embeds + logits = text_embeds self_cond_embed = embed if self.self_cond else None diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 0389952..fb29613 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -149,14 +149,17 @@ def save_validation_images( self.accelerator.print( f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}" ) - - images = self.model.generate( - validation_prompts, - cond_images=cond_image, - cond_scale=cond_scale, - temperature=temperature, - timesteps=timesteps, - ).to(self.accelerator.device) + images = [] + for text in validation_prompts: + images.append( + self.model.generate( + (text,), + cond_images=cond_image, + cond_scale=cond_scale, + temperature=temperature, + timesteps=timesteps, + ).to(self.accelerator.device) + ) save_dir = self.results_dir.joinpath("MaskGit") save_dir.mkdir(exist_ok=True, parents=True) @@ -189,6 +192,7 @@ def train(self): input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device ) else: + print(text) clip_model, clip_tokenizer = self.clip_model inputs = [token[1:-1] for token in clip_tokenizer(text, truncation=True).input_ids] diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 337b258..d0ed167 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -761,10 +761,24 @@ def main(): use_clip=args.use_clip, ) + if args.use_clip: + model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path).to( + accelerator.device + ) + tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path) + + clip = (model, tokenizer) + else: + model = None + tokenizer = None + clip = None + # (2) pass your trained VAE and the base transformer to MaskGit maskgit = MaskGit( vae=vae, # vqgan vae transformer=transformer, # transformer + clip=model, + clip_tokenizer=tokenizer, accelerator=accelerator, # accelerator image_size=args.image_size, # image size cond_drop_prob=args.cond_drop_prob, # conditional dropout, for classifier free guidance @@ -999,16 +1013,6 @@ def main(): args.batch_size, ) - if args.use_clip: - model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path).to( - accelerator.device - ) - tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path) - - clip = (model, tokenizer) - else: - clip = None - # Create the trainer accelerator.wait_for_everyone() trainer = MaskGitTrainer(
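For reference, the long-prompt handling that PATCH 6/7 adds to both the trainer and MaskGit.generate() follows the same pattern: tokenize the caption with the Hugging Face CLIP tokenizer, split the token ids into 75-token chunks, wrap each chunk in BOS/EOS so it fits CLIP's 77-token window, encode the chunks separately with CLIPTextModel, and concatenate the per-token hidden states along the sequence axis before projecting them to the transformer dimension. The sketch below is illustrative only, not the repository's exact code; the helper name `encode_long_prompt` and the `chunk_len` parameter are made up for this example, while the model name and the 512-dim text hidden size match what the patches use.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval()


@torch.no_grad()
def encode_long_prompt(text: str, chunk_len: int = 75) -> torch.Tensor:
    # Raw token ids without the BOS/EOS specials; may exceed CLIP's 77-token window.
    ids = tokenizer(text, add_special_tokens=False).input_ids
    bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id

    chunks = [ids[i : i + chunk_len] for i in range(0, len(ids), chunk_len)] or [[]]
    pieces = []
    for i, chunk in enumerate(chunks):
        # Re-wrap every chunk so it looks like a normal CLIP text input.
        input_ids = torch.tensor([[bos, *chunk, eos]])
        hidden = text_encoder(input_ids=input_ids).last_hidden_state  # (1, n, 512)
        # Keep BOS only on the first chunk and EOS only on the last one, so the
        # concatenated sequence carries a single pair of special-token embeddings.
        start = 0 if i == 0 else 1
        end = hidden.shape[1] if i == len(chunks) - 1 else hidden.shape[1] - 1
        pieces.append(hidden[:, start:end])

    # (1, total_tokens, 512) -- usable as `text_embeds`, with the transformer's
    # text_embed_proj mapping 512 to the model dimension when they differ.
    return torch.cat(pieces, dim=1)


# A prompt well past the 77-token limit still produces one embedding tensor.
long_prompt = "a photo of a dog " * 40
print(encode_long_prompt(long_prompt).shape)

With `--use_clip` enabled, embeddings produced this way replace the T5 path, which is why the patches also skip the T5 tokenizer in the dataset classes and generate validation images one prompt at a time.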