Zero dropped captions (#77)
* zero out dropped captions

* add if statement for compatibility with other dataloaders

* set zero_dropped_captions default to False
jazcollins authored Oct 13, 2023
1 parent dc51257 commit 767db4f
Showing 2 changed files with 18 additions and 0 deletions.
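
Taken together, the two files implement a single mechanism: the dataset emits a per-sample `drop_caption_mask` (0.0 when a caption is dropped and `zero_dropped_captions` is enabled, 1.0 otherwise), and the model multiplies its text conditioning by that mask in `forward()`. Below is a minimal sketch of the dataset-side logic, condensed from the diff that follows; the placeholder caption string is illustrative only.

```python
import torch

def caption_and_mask(sample_caption: str, caption_drop_prob: float,
                     zero_dropped_captions: bool) -> tuple:
    """Condensed restatement of the new __getitem__ caption-drop logic."""
    if torch.rand(1) < caption_drop_prob:
        caption = ''  # dropped caption is replaced with the empty string
        mask = 0.0 if zero_dropped_captions else 1.0
    else:
        caption = sample_caption
        mask = 1.0  # kept captions always get a mask of 1.0
    return caption, mask
```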
12 changes: 12 additions & 0 deletions diffusion/datasets/image_caption.py
@@ -43,6 +43,8 @@ class StreamingImageCaptionDataset(StreamingDataset):
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
sdxl (bool): Whether or not we're training SDXL. Default: `False`.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
"""

@@ -60,6 +62,7 @@ def __init__(
image_key: str = 'image',
caption_key: str = 'caption',
sdxl: bool = False,
zero_dropped_captions: bool = False,
**streaming_kwargs,
) -> None:

@@ -87,6 +90,7 @@ def __init__(
self.image_size = image_size
self.image_key = image_key
self.caption_key = caption_key
self.zero_dropped_captions = zero_dropped_captions

def __getitem__(self, index):
sample = super().__getitem__(index)
@@ -122,12 +126,17 @@ def __getitem__(self, index):
# Caption
if torch.rand(1) < self.caption_drop_prob:
caption = ''
if self.zero_dropped_captions:
out['drop_caption_mask'] = 0.0
else:
out['drop_caption_mask'] = 1.0
else:
caption = sample[self.caption_key]
if isinstance(caption, List) and self.caption_selection == 'first':
caption = caption[0]
if isinstance(caption, List) and self.caption_selection == 'random':
caption = random.sample(caption, k=1)[0]
out['drop_caption_mask'] = 1.0

max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore
tokenized_caption = self.tokenizer(caption,
@@ -158,6 +167,7 @@ def build_streaming_image_caption_dataloader(
image_key: str = 'image',
caption_key: str = 'caption',
rand_crop: bool = False,
zero_dropped_captions: bool = True,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
@@ -178,6 +188,7 @@ def build_streaming_image_caption_dataloader(
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
@@ -240,6 +251,7 @@ def build_streaming_image_caption_dataloader(
caption_key=caption_key,
batch_size=batch_size,
sdxl=sdxl,
zero_dropped_captions=zero_dropped_captions,
**streaming_kwargs,
)

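For context, here is a hypothetical call to the updated builder. The dataset locations and the parameter names `remote` and `local` are assumptions for illustration; only the keyword arguments visible in this diff (`batch_size`, `image_key`, `caption_key`, `rand_crop`, `zero_dropped_captions`) are confirmed by the change.

```python
from diffusion.datasets.image_caption import build_streaming_image_caption_dataloader

# Hypothetical dataset locations; parameter names `remote` and `local` are assumed.
dataloader = build_streaming_image_caption_dataloader(
    remote='s3://my-bucket/image-caption-mds',
    local='/tmp/image-caption-mds',
    batch_size=32,
    image_key='image',
    caption_key='caption',
    rand_crop=False,
    zero_dropped_captions=True,  # builder default per this diff; the dataset class defaults to False
)
```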
6 changes: 6 additions & 0 deletions diffusion/models/stable_diffusion.py
@@ -196,6 +196,12 @@ def forward(self, batch):
# Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515)
latents *= self.latent_scale

# Zero dropped captions if needed
if 'drop_caption_mask' in batch.keys():
conditioning *= batch['drop_caption_mask'].view(-1, 1, 1)
if pooled_conditioning is not None:
pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1)

# Sample the diffusion timesteps
timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device)
# Add noise to the inputs (forward diffusion)
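The mask broadcasting above can be checked in isolation; a standalone sketch with assumed tensor shapes (batch of 4, sequence length 77, embedding dim 768):

```python
import torch

batch, seq_len, dim = 4, 77, 768                 # assumed shapes for illustration
conditioning = torch.randn(batch, seq_len, dim)  # per-token text embeddings
pooled_conditioning = torch.randn(batch, dim)    # pooled embeddings (SDXL path)

# 1.0 keeps a sample's text conditioning; 0.0 zeros it because its caption was dropped.
drop_caption_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])

conditioning = conditioning * drop_caption_mask.view(-1, 1, 1)
pooled_conditioning = pooled_conditioning * drop_caption_mask.view(-1, 1)

assert torch.all(conditioning[1] == 0)         # dropped sample is fully zeroed
assert torch.all(pooled_conditioning[1] == 0)
```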
