Full SDXL Model (#67)

* random crop

* zero init trick

* add intentionally buggy clipping

* fix docstring and update diffusers version

* fix attention clipping, add to sdxl, fix xformers import when not installed

* big sdxl commit, no style check

* fix style and pyright

* print sdxl statement

* add sdxl logic to generate

* allow setting SDXLTextEncoder device

* sdxltextencoder edits

* split conditioning

* remove prints

* microconditioning and cleaning up comments

* fix style

* fix dropout dtype

* rm local streaming

* Update diffusion/datasets/image_caption.py

Co-authored-by: Landan Seguin <[email protected]>

* use RandomCrop, fix LogDiffusionImages bug

* have tokenizers pass dict output

* add to layers.py docs

* override prediction_type in inference_noise_scheduler

* Update diffusion/models/stable_diffusion.py

Co-authored-by: Landan Seguin <[email protected]>

* fix style

* log_diffusion_images.py fix

* pass tokenized prompts as batch_size x 2 x max_length shape

* stack tokenizer output to match

* fix negative prompt classifier free guidance

* _prepare_text_embeddings fix

* add negative_prompt_embeds to zero_out_negative_prompt check

---------

Co-authored-by: Landan Seguin <[email protected]>
jazcollins and Landanjs authored Oct 4, 2023
1 parent ccf58bd commit 35f5a57
Showing 7 changed files with 589 additions and 69 deletions.
14 changes: 9 additions & 5 deletions diffusion/callbacks/log_diffusion_images.py
@@ -27,8 +27,8 @@ class LogDiffusionImages(Callback):
             the text prompt, usually at the expense of lower image quality.
             Default: ``0.0``.
         text_key (str, optional): Key in the batch to use for text prompts. Default: ``'captions'``.
-        tokenized_prompts (torch.LongTensor, optional): Batch of pre-tokenized prompts
-            to use for evaluation. Default: ``None``.
+        tokenized_prompts (torch.LongTensor or List[torch.LongTensor], optional): Batch of pre-tokenized prompts
+            to use for evaluation. If SDXL, this will be a list of two pre-tokenized prompts. Default: ``None``.
         seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
             Default: ``1138``.
         use_table (bool): Whether to make a table of the images or not. Default: ``False``.
@@ -63,13 +63,17 @@ def eval_batch_end(self, state: State, logger: Logger):
         model = state.model
 
         if self.tokenized_prompts is None:
-            tokenized_prompts = [
+            self.tokenized_prompts = [
                 model.tokenizer(p, padding='max_length', truncation=True,
                                 return_tensors='pt')['input_ids']  # type: ignore
                 for p in self.prompts
             ]
-            self.tokenized_prompts = torch.cat(tokenized_prompts)
-        self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)
+            if model.sdxl:
+                self.tokenized_prompts = torch.stack([torch.cat(tp) for tp in self.tokenized_prompts
+                                                      ])  # [B, 2, max_length]
+            else:
+                self.tokenized_prompts = torch.cat(self.tokenized_prompts)  # type: ignore
+        self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)  # type: ignore
 
         # Generate images
         with get_precision_context(state.precision):
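Note on the shapes above: SDXL conditions on two text encoders, so each prompt tokenizes into two id tensors rather than one. A minimal sketch of the stacking logic, with illustrative tensor shapes and vocab size that are assumptions, not values from the repo:

import torch

# Illustrative only: SDXL pairs two tokenizers (CLIP ViT-L and OpenCLIP bigG),
# each assumed to emit [1, max_length] input_ids per prompt.
max_length = 77
prompts = ['a photo of a cat', 'a watercolor landscape']
per_prompt = [[torch.randint(0, 49408, (1, max_length)) for _ in range(2)] for _ in prompts]

# torch.cat joins one prompt's two id tensors -> [2, max_length];
# torch.stack batches them -> [B, 2, max_length], matching the callback.
batched = torch.stack([torch.cat(tp) for tp in per_prompt])
assert batched.shape == (len(prompts), 2, max_length)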
86 changes: 75 additions & 11 deletions diffusion/datasets/image_caption.py
@@ -3,6 +3,7 @@
 
 """Streaming Image-Caption dataset."""
 
+import logging
 import random
 from io import BytesIO
 from typing import Callable, Dict, List, Optional, Sequence, Union
@@ -14,7 +15,10 @@
 from torchvision import transforms
 from transformers import AutoTokenizer
 
-from diffusion.datasets.laion.transforms import LargestCenterSquare
+from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare, RandomCropSquareReturnTransform
+from diffusion.models.models import SDXLTokenizer
+
+log = logging.getLogger(__name__)
 
 # Disable PIL max image size limit
 Image.MAX_IMAGE_PIXELS = None
@@ -29,13 +33,16 @@ class StreamingImageCaptionDataset(StreamingDataset):
         remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``.
         local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``.
         tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
         caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
+        microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
         caption_selection (str): If there are multiple captions, specifies how to select a single caption.
             'first' selects the first caption in the list and 'random' selects a random caption in the list.
             If there is only one caption, this argument is ignored. Default: ``'first'``.
         transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
         image_size (Optional[int]): The size to resize the image to. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
+        sdxl (bool): Whether or not we're training SDXL. Default: ``False``.
         **streaming_kwargs: Additional arguments to pass in the construction of the ``StreamingDataset``.
     """

@@ -46,11 +53,13 @@ def __init__(
         local: Optional[str] = None,
         tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base',
         caption_drop_prob: float = 0.0,
+        microcond_drop_prob: float = 0.0,
         caption_selection: str = 'first',
         transform: Optional[Callable] = None,
         image_size: Optional[int] = None,
         image_key: str = 'image',
         caption_key: str = 'caption',
+        sdxl: bool = False,
         **streaming_kwargs,
     ) -> None:
 
@@ -65,8 +74,15 @@ def __init__(
             raise ValueError(f'Invalid caption selection: {caption_selection}. Must be one of [random, first]')
 
         self.transform = transform
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer')
+        self.sdxl = sdxl
+        if self.sdxl:
+            self.tokenizer = SDXLTokenizer(tokenizer_name_or_path)
+            self.sdxl_crop = RandomCropSquareReturnTransform(image_size)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer')
+            self.sdxl_crop = None
         self.caption_drop_prob = caption_drop_prob
+        self.microcond_drop_prob = microcond_drop_prob
         self.caption_selection = caption_selection
         self.image_size = image_size
         self.image_key = image_key
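SDXLTokenizer is imported from diffusion.models.models and is not shown in this diff; judging from how it is called in __getitem__ below, it wraps SDXL's two tokenizers and returns both sets of input_ids. A hedged sketch of such a wrapper, written here for illustration only (the real class may differ):

from transformers import AutoTokenizer

class DualTokenizer:
    """Sketch of a dual-tokenizer wrapper; the repo's SDXLTokenizer may differ."""

    def __init__(self, name_or_path: str):
        # SDXL checkpoints ship two tokenizers in the 'tokenizer' and 'tokenizer_2' subfolders.
        self.tokenizer = AutoTokenizer.from_pretrained(name_or_path, subfolder='tokenizer')
        self.tokenizer_2 = AutoTokenizer.from_pretrained(name_or_path, subfolder='tokenizer_2')

    def __call__(self, text, padding, truncation, return_tensors, max_length=None):
        # When max_length is None, fall back to each tokenizer's own model_max_length,
        # matching the max_length=None call in __getitem__ below.
        out_1 = self.tokenizer(text, padding=padding, truncation=truncation,
                               return_tensors=return_tensors,
                               max_length=max_length or self.tokenizer.model_max_length)
        out_2 = self.tokenizer_2(text, padding=padding, truncation=truncation,
                                 return_tensors=return_tensors,
                                 max_length=max_length or self.tokenizer_2.model_max_length)
        # Return a list of two id tensors so each text encoder gets its own ids.
        return {'input_ids': [out_1['input_ids'], out_2['input_ids']]}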
@@ -81,6 +97,25 @@ def __getitem__(self, index):
         img = Image.open(BytesIO(sample[self.image_key]))
         if img.mode != 'RGB':
             img = img.convert('RGB')
+
+        out = {}
+        # Image transforms
+        if self.sdxl and self.sdxl_crop:
+            img, crop_top, crop_left, image_height, image_width = self.sdxl_crop(img)
+            out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])
+            out['cond_original_size'] = torch.tensor([image_width, image_height])
+            out['cond_target_size'] = torch.tensor([self.image_size, self.image_size])
+
+            # Microconditioning dropout as in Stability repo
+            # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160
+            if torch.rand(1) < self.microcond_drop_prob:
+                out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0
+            if torch.rand(1) < self.microcond_drop_prob:
+                out['cond_original_size'] = out['cond_original_size'] * 0
+            if torch.rand(1) < self.microcond_drop_prob:
+                out['cond_target_size'] = out['cond_target_size'] * 0
+        else:
+            crop_top, crop_left, image_height, image_width = None, None, None, None
         if self.transform is not None:
             img = self.transform(img)
@@ -93,13 +128,21 @@ def __getitem__(self, index):
             caption = caption[0]
         if isinstance(caption, List) and self.caption_selection == 'random':
             caption = random.sample(caption, k=1)[0]
 
+        max_length = None if self.sdxl else self.tokenizer.model_max_length  # type: ignore
         tokenized_caption = self.tokenizer(caption,
                                            padding='max_length',
-                                           max_length=self.tokenizer.model_max_length,
+                                           max_length=max_length,
                                            truncation=True,
-                                           return_tensors='pt')['input_ids'][0]
-
-        return {'image': img, 'captions': tokenized_caption}
+                                           return_tensors='pt')['input_ids']
+        if self.sdxl:
+            tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
+            tokenized_caption = torch.stack(tokenized_caption)
+        else:
+            tokenized_caption = tokenized_caption.squeeze()
+        out['image'] = img
+        out['captions'] = tokenized_caption
+        return out
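The microconditioning dropout above zeroes each conditioning tensor independently with probability microcond_drop_prob, so the model also learns to generate without any one of the crop, original-size, or target-size signals. A standalone sketch of the same idea; the helper name and example values are ours, not the repo's:

import torch

def drop_microconditioning(out: dict, p: float) -> dict:
    # Zero each SDXL microconditioning tensor independently with probability p,
    # mirroring the dropout in __getitem__ above.
    for key in ('cond_crops_coords_top_left', 'cond_original_size', 'cond_target_size'):
        if torch.rand(1) < p:
            out[key] = out[key] * 0
    return out

example = {
    'cond_crops_coords_top_left': torch.tensor([12, 40]),
    'cond_original_size': torch.tensor([1920, 1080]),
    'cond_target_size': torch.tensor([1024, 1024]),
}
drop_microconditioning(example, p=0.1)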


@@ -108,11 +151,13 @@ def build_streaming_image_caption_dataloader(
     batch_size: int,
     tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base',
     caption_drop_prob: float = 0.0,
+    microcond_drop_prob: float = 0.0,
     resize_size: int = 256,
     caption_selection: str = 'first',
     transform: Optional[List[Callable]] = None,
     image_key: str = 'image',
     caption_key: str = 'caption',
+    rand_crop: bool = False,
     streaming_kwargs: Optional[Dict] = None,
     dataloader_kwargs: Optional[Dict] = None,
 ):
@@ -124,13 +169,15 @@ def build_streaming_image_caption_dataloader(
         batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
         tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
         caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
+        microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
         resize_size (int): The size to resize the image to. Default: ``256``.
         caption_selection (str): If there are multiple captions, specifies how to select a single caption.
             'first' selects the first caption in the list and 'random' selects a random caption in the list.
             If there is only one caption, this argument is ignored. Default: ``'first'``.
         transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
+        rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``.
         streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
         dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
     """
@@ -156,26 +203,43 @@ def build_streaming_image_caption_dataloader(
         for r, l in zip(remote, local):
             streams.append(Stream(remote=r, local=l))
 
+    # Infer SDXL from tokenizer path
+    if tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0':
+        log.info('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.')
+        sdxl = True
+    else:
+        sdxl = False
+
     # Setup the transforms to apply
+    crop_transform = RandomCropSquare(resize_size) if rand_crop else LargestCenterSquare(resize_size)
     if transform is None:
-        transform = [
-            LargestCenterSquare(resize_size),
-            transforms.ToTensor(),
-            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize from 0 to 1 to -1 to 1
-        ]
+        if sdxl:
+            # Crop will return parameters, so it is applied separately in the dataset
+            transform = [
+                transforms.ToTensor(),
+                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            ]
+        else:
+            transform = [
+                crop_transform,
+                transforms.ToTensor(),
+                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize from 0 to 1 to -1 to 1
+            ]
     transform = transforms.Compose(transform)
     assert isinstance(transform, Callable)
 
     dataset = StreamingImageCaptionDataset(
         streams=streams,
         tokenizer_name_or_path=tokenizer_name_or_path,
         caption_drop_prob=caption_drop_prob,
+        microcond_drop_prob=microcond_drop_prob,
         caption_selection=caption_selection,
         transform=transform,
         image_size=resize_size,
         image_key=image_key,
         caption_key=caption_key,
         batch_size=batch_size,
+        sdxl=sdxl,
         **streaming_kwargs,
     )

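Putting the builder together, a hedged usage sketch: the remote/local arguments are taken from the zip(remote, local) loop earlier in the function, and the paths and hyperparameters below are placeholders, not values from the repo. Passing the SDXL tokenizer path is what flips the sdxl branch above:

dataloader = build_streaming_image_caption_dataloader(
    remote='s3://my-bucket/image-caption-mds',  # hypothetical dataset path
    local='/tmp/image-caption-mds',             # hypothetical local cache
    batch_size=16,
    tokenizer_name_or_path='stabilityai/stable-diffusion-xl-base-1.0',  # triggers the SDXL branch
    caption_drop_prob=0.1,
    microcond_drop_prob=0.1,
    resize_size=1024,
    rand_crop=True,
)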
37 changes: 36 additions & 1 deletion diffusion/datasets/laion/transforms.py
@@ -1,9 +1,11 @@
 # Copyright 2022 MosaicML Diffusion authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Transforms for the laion dataset."""
+"""Transforms for the training and eval dataset."""
 
 import torchvision.transforms as transforms
+from torchvision.transforms import RandomCrop
+from torchvision.transforms.functional import crop
 
 
 class LargestCenterSquare:
@@ -19,3 +21,36 @@ def __call__(self, img):
         # Then take a center crop to a square.
         img = self.center_crop(img)
         return img
+
+
+class RandomCropSquare:
+    """Randomly crop a square from a PIL image."""
+
+    def __init__(self, size):
+        self.size = size
+        self.random_crop = RandomCrop(size)
+
+    def __call__(self, img):
+        # First, resize the image such that the smallest side is self.size while preserving aspect ratio.
+        img = transforms.functional.resize(img, self.size, antialias=True)
+        # Then take a random crop to a square.
+        c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
+        img = crop(img, c_top, c_left, h, w)
+        return img
+
+
+class RandomCropSquareReturnTransform:
+    """Randomly crop a square from a PIL image and return the crop parameters."""
+
+    def __init__(self, size):
+        self.size = size
+        self.random_crop = RandomCrop(size)
+
+    def __call__(self, img):
+        # First, resize the image such that the smallest side is self.size while preserving aspect ratio.
+        orig_w, orig_h = img.size
+        img = transforms.functional.resize(img, self.size, antialias=True)
+        # Then take a random crop to a square and return the crop parameters.
+        c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
+        img = crop(img, c_top, c_left, h, w)
+        return img, c_top, c_left, orig_h, orig_w
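A quick usage sketch of the new transform, showing how the returned crop parameters line up with the microconditioning keys used in image_caption.py; the image path is a placeholder:

from PIL import Image

crop_fn = RandomCropSquareReturnTransform(1024)
img = Image.open('sample.jpg').convert('RGB')  # placeholder image path
img, c_top, c_left, orig_h, orig_w = crop_fn(img)

# These values feed SDXL's microconditioning:
# cond_crops_coords_top_left=[c_top, c_left], cond_original_size=[orig_w, orig_h]
print(f'crop origin: ({c_top}, {c_left}); original size: {orig_w}x{orig_h}')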