Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zero dropped captions #77

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class StreamingImageCaptionDataset(StreamingDataset):
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
sdxl (bool): Whether or not we're training SDXL. Default: `False`.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.

**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
"""

Expand All @@ -60,6 +62,7 @@ def __init__(
image_key: str = 'image',
caption_key: str = 'caption',
sdxl: bool = False,
zero_dropped_captions: bool = False,
**streaming_kwargs,
) -> None:

Expand Down Expand Up @@ -87,6 +90,7 @@ def __init__(
self.image_size = image_size
self.image_key = image_key
self.caption_key = caption_key
self.zero_dropped_captions = zero_dropped_captions

def __getitem__(self, index):
sample = super().__getitem__(index)
Expand Down Expand Up @@ -122,12 +126,17 @@ def __getitem__(self, index):
# Caption
if torch.rand(1) < self.caption_drop_prob:
caption = ''
if self.zero_dropped_captions:
out['drop_caption_mask'] = 0.0
else:
out['drop_caption_mask'] = 1.0
else:
caption = sample[self.caption_key]
if isinstance(caption, List) and self.caption_selection == 'first':
caption = caption[0]
if isinstance(caption, List) and self.caption_selection == 'random':
caption = random.sample(caption, k=1)[0]
out['drop_caption_mask'] = 1.0

max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore
tokenized_caption = self.tokenizer(caption,
Expand Down Expand Up @@ -158,6 +167,7 @@ def build_streaming_image_caption_dataloader(
image_key: str = 'image',
caption_key: str = 'caption',
rand_crop: bool = False,
zero_dropped_captions: bool = True,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
Expand All @@ -178,6 +188,7 @@ def build_streaming_image_caption_dataloader(
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
Expand Down Expand Up @@ -240,6 +251,7 @@ def build_streaming_image_caption_dataloader(
caption_key=caption_key,
batch_size=batch_size,
sdxl=sdxl,
zero_dropped_captions=zero_dropped_captions,
**streaming_kwargs,
)

Expand Down
6 changes: 6 additions & 0 deletions diffusion/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@ def forward(self, batch):
# Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515)
latents *= self.latent_scale

# Zero dropped captions if needed
if 'drop_caption_mask' in batch.keys():
conditioning *= batch['drop_caption_mask'].view(-1, 1, 1)
if pooled_conditioning is not None:
pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1)

# Sample the diffusion timesteps
timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device)
# Add noise to the inputs (forward diffusion)
Expand Down