diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py
index 6245fafe..b31a19b3 100644
--- a/diffusion/datasets/image_caption.py
+++ b/diffusion/datasets/image_caption.py
@@ -43,6 +43,8 @@ class StreamingImageCaptionDataset(StreamingDataset):
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
         sdxl (bool): Whether or not we're training SDXL. Default: `False`.
+        zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.
+        **streaming_kwargs: Additional arguments to pass to the ``StreamingDataset`` constructor.
     """

@@ -60,6 +62,7 @@ def __init__(
         image_key: str = 'image',
         caption_key: str = 'caption',
         sdxl: bool = False,
+        zero_dropped_captions: bool = False,
         **streaming_kwargs,
     ) -> None:
@@ -87,6 +90,7 @@ def __init__(
         self.image_size = image_size
         self.image_key = image_key
         self.caption_key = caption_key
+        self.zero_dropped_captions = zero_dropped_captions

     def __getitem__(self, index):
         sample = super().__getitem__(index)
@@ -122,12 +126,17 @@ def __getitem__(self, index):
         # Caption
         if torch.rand(1) < self.caption_drop_prob:
             caption = ''
+            if self.zero_dropped_captions:
+                out['drop_caption_mask'] = 0.0
+            else:
+                out['drop_caption_mask'] = 1.0
         else:
             caption = sample[self.caption_key]
             if isinstance(caption, List) and self.caption_selection == 'first':
                 caption = caption[0]
             if isinstance(caption, List) and self.caption_selection == 'random':
                 caption = random.sample(caption, k=1)[0]
+            out['drop_caption_mask'] = 1.0

         max_length = None if self.sdxl else self.tokenizer.model_max_length  # type: ignore
         tokenized_caption = self.tokenizer(caption,
@@ -158,6 +167,7 @@ def build_streaming_image_caption_dataloader(
     image_key: str = 'image',
     caption_key: str = 'caption',
     rand_crop: bool = False,
+    zero_dropped_captions: bool = True,
     streaming_kwargs: Optional[Dict] = None,
     dataloader_kwargs: Optional[Dict] = None,
 ):
@@ -178,6 +188,7 @@ def build_streaming_image_caption_dataloader(
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
         rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``.
+        zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``.
         streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
         dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
""" @@ -240,6 +251,7 @@ def build_streaming_image_caption_dataloader( caption_key=caption_key, batch_size=batch_size, sdxl=sdxl, + zero_dropped_captions=zero_dropped_captions, **streaming_kwargs, ) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 0bd95706..f8e94954 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -196,6 +196,12 @@ def forward(self, batch): # Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515) latents *= self.latent_scale + # Zero dropped captions if needed + if 'drop_caption_mask' in batch.keys(): + conditioning *= batch['drop_caption_mask'].view(-1, 1, 1) + if pooled_conditioning is not None: + pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1) + # Sample the diffusion timesteps timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device) # Add noise to the inputs (forward diffusion)