
Commit

Fix rename conflict
corystephenson-db committed Aug 20, 2024
2 parents 24c5916 + 52088bf commit 58ab320
Showing 15 changed files with 1,689 additions and 40 deletions.
113 changes: 101 additions & 12 deletions diffusion/callbacks/log_diffusion_images.py
@@ -3,13 +3,15 @@

"""Logger for generated images."""

import gc
from math import ceil
from typing import List, Optional, Tuple, Union

import torch
from composer import Callback, Logger, State
from composer.core import TimeUnit, get_precision_context
from torch.nn.parallel import DistributedDataParallel
from transformers import AutoModel, AutoTokenizer, CLIPTextModel


class LogDiffusionImages(Callback):
@@ -35,6 +37,9 @@ class LogDiffusionImages(Callback):
seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
Default: ``1138``.
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
t5_encoder (str, optional): Path to the T5 encoder to use as the second text encoder. Default: ``None``.
clip_encoder (str, optional): Path to the CLIP encoder to use as the first text encoder. Default: ``None``.
cache_dir (str, optional): Path for HuggingFace to cache files while downloading the models. Default: ``/tmp/hf_files``.
"""

def __init__(self,
@@ -45,14 +50,18 @@ def __init__(self,
guidance_scale: float = 0.0,
rescaled_guidance: Optional[float] = None,
seed: Optional[int] = 1138,
use_table: bool = False):
use_table: bool = False,
t5_encoder: Optional[str] = None,
clip_encoder: Optional[str] = None,
cache_dir: Optional[str] = '/tmp/hf_files'):
self.prompts = prompts
self.size = (size, size) if isinstance(size, int) else size
self.num_inference_steps = num_inference_steps
self.guidance_scale = guidance_scale
self.rescaled_guidance = rescaled_guidance
self.seed = seed
self.use_table = use_table
self.cache_dir = cache_dir

# Batch prompts
batch_size = len(prompts) if batch_size is None else batch_size
@@ -62,6 +71,66 @@
start, end = i * batch_size, (i + 1) * batch_size
self.batched_prompts.append(prompts[start:end])

if (t5_encoder is not None and clip_encoder is None) or (t5_encoder is None and clip_encoder is not None):
raise ValueError('Must specify both t5_encoder and clip_encoder, or neither.')

self.precomputed_latents = False
self.batched_latents = []
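# If encoders were provided, precompute the prompt latents once at init so eval-time
# image generation can run without keeping the text encoders in GPU memory.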
if t5_encoder:
self.precomputed_latents = True
t5_tokenizer = AutoTokenizer.from_pretrained(t5_encoder, cache_dir=self.cache_dir, local_files_only=True)
clip_tokenizer = AutoTokenizer.from_pretrained(clip_encoder,
subfolder='tokenizer',
cache_dir=self.cache_dir,
local_files_only=True)

t5_model = AutoModel.from_pretrained(t5_encoder,
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).encoder.cuda().eval()
clip_model = CLIPTextModel.from_pretrained(clip_encoder,
subfolder='text_encoder',
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).cuda().eval()

for batch in self.batched_prompts:
latent_batch = {}
tokenized_t5 = t5_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
t5_ids = tokenized_t5['input_ids'].cuda()
t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
t5_attention_mask = t5_attention_mask.cpu().to(torch.long)

tokenized_clip = clip_tokenizer(batch,
padding='max_length',
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_attention_mask = tokenized_clip['attention_mask'].cuda()
clip_ids = tokenized_clip['input_ids'].cuda()
clip_outputs = clip_model(input_ids=clip_ids,
attention_mask=clip_attention_mask,
output_hidden_states=True)
clip_latents = clip_outputs.hidden_states[-2].cpu()
clip_pooled = clip_outputs[1].cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch['T5_LATENTS'] = t5_latents
latent_batch['CLIP_LATENTS'] = clip_latents
latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
latent_batch['CLIP_POOLED'] = clip_pooled
self.batched_latents.append(latent_batch)

del t5_model
del clip_model
gc.collect()
torch.cuda.empty_cache()

def eval_start(self, state: State, logger: Logger):
# Get the model object if it has been wrapped by DDP to access the image generation function.
if isinstance(state.model, DistributedDataParallel):
@@ -72,17 +141,37 @@ def eval_start(self, state: State, logger: Logger):
# Generate images
with get_precision_context(state.precision):
all_gen_images = []
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
if self.precomputed_latents:
for batch in self.batched_latents:
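# Project the cached T5/CLIP latents with the model's projection layers and
# generate images directly from prompt embeddings instead of raw prompt strings.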
pooled_prompt = batch['CLIP_POOLED'].cuda()
prompt_mask = batch['ATTENTION_MASK'].cuda()
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)

gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
else:
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
gen_images = torch.cat(all_gen_images)

# Log images to wandb
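For orientation, a minimal sketch of how the updated callback might be instantiated with the new encoder arguments. The prompt, image size, and encoder repo names below are placeholders rather than values from this commit; the checkpoints must already be downloaded into cache_dir (the callback loads with local_files_only=True), and the CLIP repo needs 'tokenizer' and 'text_encoder' subfolders.

from diffusion.callbacks.log_diffusion_images import LogDiffusionImages

# Hypothetical values for illustration only.
callback = LogDiffusionImages(
    prompts=['a photo of a corgi wearing a top hat'],
    size=512,
    guidance_scale=7.0,
    t5_encoder='google/t5-v1_1-xxl',
    clip_encoder='stabilityai/stable-diffusion-xl-base-1.0',
    cache_dir='/tmp/hf_files',
)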
14 changes: 7 additions & 7 deletions diffusion/datasets/image_caption.py
@@ -15,8 +15,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransorm,
RandomCropBucketedAspectRatioTransorm, RandomCropSquare)
from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams
from diffusion.models.text_encoder import MultiTokenizer

@@ -94,8 +94,8 @@ def __init__(
self.image_key = image_key
self.caption_key = caption_key
self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
if isinstance(self.crop, RandomCropBucketedAspectRatioTransorm):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransorm'
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
self.zero_dropped_captions = zero_dropped_captions

self.tokenizer = tokenizer
@@ -113,7 +113,7 @@ def __getitem__(self, index):
orig_w, orig_h = img.size

# Image transforms
if isinstance(self.crop, RandomCropBucketedAspectRatioTransorm):
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
elif self.crop is not None:
img, crop_top, crop_left = self.crop(img)
@@ -258,10 +258,10 @@ def build_streaming_image_caption_dataloader(
elif crop_type == 'random':
crop = RandomCropSquare(resize_size)
elif crop_type == 'aspect_ratio':
crop = RandomCropAspectRatioTransorm(resize_size, ar_bucket_boundaries) # type: ignore
crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore
elif crop_type == 'bucketed_aspect_ratio':
assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
crop = RandomCropBucketedAspectRatioTransorm(resize_size) # type: ignore
crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore
else:
crop = None

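As a usage note for the bucketed crop path, a rough sketch of calling the dataloader builder. Only crop_type and aspect_ratio_bucket_key appear in this diff; the remote/local/batch_size arguments and the bucket key name are assumptions about the full signature, which is not shown here.

from diffusion.datasets.image_caption import build_streaming_image_caption_dataloader

dataloader = build_streaming_image_caption_dataloader(
    remote=['s3://my-bucket/image-caption-mds'],  # assumed argument and value
    local=['/tmp/mds-cache'],                     # assumed argument and value
    batch_size=32,                                # assumed argument and value
    crop_type='bucketed_aspect_ratio',
    aspect_ratio_bucket_key='aspect_ratio_bucket_index',  # hypothetical key present in each sample
)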
4 changes: 2 additions & 2 deletions diffusion/datasets/laion/transforms.py
@@ -46,7 +46,7 @@ def __call__(self, img):
return img, c_top, c_left


class RandomCropAspectRatioTransorm:
class RandomCropAspectRatioTransform:
"""Assigns an image to a arbitrary set of aspect ratio buckets, then resizes and crops to fit into the bucket.
Args:
@@ -114,7 +114,7 @@ def __call__(self, img):
return img, c_top, c_left


class RandomCropBucketedAspectRatioTransorm:
class RandomCropBucketedAspectRatioTransform:
"""Assigns an image to a arbitrary set of aspect ratio buckets, then resizes and crops to fit into the bucket.
This transform requires the desired aspect ratio bucket to be specified manually in the call to the transform.
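To make the renamed transforms' differing call conventions concrete (matching their use in image_caption.py above), a small sketch follows. The resize size and bucket value are placeholders; their exact expected formats are defined in transforms.py and not shown in this diff.

from PIL import Image

from diffusion.datasets.laion.transforms import (RandomCropAspectRatioTransform,
                                                 RandomCropBucketedAspectRatioTransform)

img = Image.new('RGB', (768, 512))  # placeholder image
resize_size = 256  # placeholder; see transforms.py for the expected format

# Picks an aspect ratio bucket for the image automatically.
crop = RandomCropAspectRatioTransform(resize_size)
img_out, crop_top, crop_left = crop(img)

# Requires the caller to pass the desired bucket at call time
# (the dataset supplies sample[aspect_ratio_bucket_key] here).
bucketed_crop = RandomCropBucketedAspectRatioTransform(resize_size)
aspect_ratio_bucket = 1.5  # hypothetical bucket value taken from the sample
img_out, crop_top, crop_left = bucketed_crop(img, aspect_ratio_bucket)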