Allow masking padding tokens in cross attention layers (#94)
* add padding attn mask to training

* remove squeeze

* torch tensorify

* handle sdxl

* encoder attn mask

* retry

* pad masking in generate() and pyright

* toggle pad masking with flag, add arg for token masks in generate()
jazcollins authored Nov 16, 2023
1 parent 3122b81 commit 5decf2a
Showing 3 changed files with 76 additions and 20 deletions.
19 changes: 12 additions & 7 deletions diffusion/datasets/image_caption.py
@@ -151,17 +151,22 @@ def __getitem__(self, index):
out['drop_caption_mask'] = 1.0

max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore
tokenized_caption = self.tokenizer(caption,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors='pt')['input_ids']
tokenizer_out = self.tokenizer(caption,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors='pt')
if self.sdxl:
tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenizer_out.input_ids]
tokenized_caption = torch.stack(tokenized_caption)
            # Take the union of both tokenizers' padding masks
attention_masks = tokenizer_out.attention_mask
attention_mask = torch.logical_or(attention_masks[0], attention_masks[1]).to(attention_masks[0].dtype)
else:
tokenized_caption = tokenized_caption.squeeze()
tokenized_caption = tokenizer_out.input_ids.squeeze()
attention_mask = tokenizer_out.attention_mask
out['captions'] = tokenized_caption
out['attention_mask'] = attention_mask
return out


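For reference, a minimal sketch of the mask union the SDXL branch above performs. The token values are illustrative: SDXL tokenizes each caption with two tokenizers, and a position stays unmasked if either tokenizer marks it as a real (non-padding) token.

import torch

# Illustrative padding masks from the two SDXL tokenizers for one caption padded to length 8.
mask_a = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
mask_b = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])

# Union of the two masks, cast back to the original integer dtype, as in the code above.
attention_mask = torch.logical_or(mask_a, mask_b).to(mask_a.dtype)
# tensor([[1, 1, 1, 1, 1, 1, 0, 0]])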
6 changes: 6 additions & 0 deletions diffusion/models/models.py
@@ -41,6 +41,7 @@ def stable_diffusion_2(
loss_bins: Optional[List] = None,
precomputed_latents: bool = False,
encode_latents_in_fp16: bool = True,
mask_pad_tokens: bool = False,
fsdp: bool = True,
clip_qkv: Optional[float] = None,
):
@@ -67,6 +68,7 @@ def stable_diffusion_2(
offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
be used. Default `None`.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False.
fsdp (bool): Whether to use FSDP. Defaults to True.
clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None.
"""
@@ -123,6 +125,7 @@ def stable_diffusion_2(
loss_bins=loss_bins,
precomputed_latents=precomputed_latents,
encode_latents_in_fp16=encode_latents_in_fp16,
mask_pad_tokens=mask_pad_tokens,
fsdp=fsdp,
)
if torch.cuda.is_available():
@@ -156,6 +159,7 @@ def stable_diffusion_xl(
loss_bins: Optional[List] = None,
precomputed_latents: bool = False,
encode_latents_in_fp16: bool = True,
mask_pad_tokens: bool = False,
fsdp: bool = True,
clip_qkv: Optional[float] = 6.0,
):
@@ -188,6 +192,7 @@ def stable_diffusion_xl(
[(0, 1)].
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False.
fsdp (bool): Whether to use FSDP. Defaults to True.
clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to 6.0. Improves stability
of training.
@@ -259,6 +264,7 @@ def stable_diffusion_xl(
loss_bins=loss_bins,
precomputed_latents=precomputed_latents,
encode_latents_in_fp16=encode_latents_in_fp16,
mask_pad_tokens=mask_pad_tokens,
fsdp=fsdp,
sdxl=True,
)
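As a usage sketch of the new flag (hedged: the import path is inferred from the file path above, and the remaining factory arguments are assumed to have workable defaults):

from diffusion.models.models import stable_diffusion_2

# Build the Stable Diffusion 2 model with cross-attention pad masking enabled; training batches
# that include an 'attention_mask' key will then have it forwarded to the UNet's cross attention.
model = stable_diffusion_2(mask_pad_tokens=True)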
71 changes: 58 additions & 13 deletions diffusion/models/stable_diffusion.py
@@ -64,6 +64,8 @@ class StableDiffusion(ComposerModel):
Default: `False`.
encode_latents_in_fp16 (bool): whether to encode latents in fp16.
Default: `False`.
mask_pad_tokens (bool): whether to mask pad tokens in cross attention.
Default: `False`.
sdxl (bool): Whether or not we're training SDXL. Default: `False`.
"""

@@ -88,6 +90,7 @@ def __init__(self,
text_latents_key: str = 'caption_latents',
precomputed_latents: bool = False,
encode_latents_in_fp16: bool = False,
mask_pad_tokens: bool = False,
fsdp: bool = False,
sdxl: bool = False):
super().__init__()
@@ -103,6 +106,7 @@ def __init__(self,
self.image_key = image_key
self.image_latents_key = image_latents_key
self.precomputed_latents = precomputed_latents
self.mask_pad_tokens = mask_pad_tokens
self.sdxl = sdxl
if self.sdxl:
self.latent_scale = 0.13025
@@ -152,6 +156,7 @@ def __init__(self,
self.text_key = text_key
self.text_latents_key = text_latents_key
self.encode_latents_in_fp16 = encode_latents_in_fp16
self.mask_pad_tokens = mask_pad_tokens
# freeze text_encoder during diffusion training
self.text_encoder.requires_grad_(False)
self.vae.requires_grad_(False)
@@ -206,6 +211,12 @@ def forward(self, batch):
if pooled_conditioning is not None:
pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1)

# Attention mask if needed
if self.mask_pad_tokens and 'attention_mask' in batch.keys():
encoder_attention_mask = batch['attention_mask']
else:
encoder_attention_mask = None

# Sample the diffusion timesteps
timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device)
# Add noise to the inputs (forward diffusion)
@@ -234,7 +245,10 @@ def forward(self, batch):
added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids}

# Forward through the model
return self.unet(noised_latents, timesteps, conditioning,
return self.unet(noised_latents,
timesteps,
conditioning,
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps

def loss(self, outputs, batch):
@@ -252,6 +266,12 @@ def eval_forward(self, batch, outputs=None):
prompts = batch[self.text_key]
height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]

# Attention mask if needed
if self.mask_pad_tokens and 'attention_mask' in batch.keys():
encoder_attention_mask = batch['attention_mask']
else:
encoder_attention_mask = None

# If SDXL, add eval-time micro-conditioning to batch
if self.sdxl:
device = self.unet.device
@@ -266,6 +286,7 @@ def eval_forward(self, batch, outputs=None):
generated_images = {}
for guidance_scale in self.val_guidance_scales:
gen_images = self.generate(tokenized_prompts=prompts,
tokenized_prompts_pad_mask=encoder_attention_mask,
height=height,
width=width,
guidance_scale=guidance_scale,
@@ -339,6 +360,8 @@ def generate(
negative_prompt: Optional[list] = None,
tokenized_prompts: Optional[torch.LongTensor] = None,
tokenized_negative_prompts: Optional[torch.LongTensor] = None,
tokenized_prompts_pad_mask: Optional[torch.LongTensor] = None,
tokenized_negative_prompts_pad_mask: Optional[torch.LongTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
height: Optional[int] = None,
@@ -369,6 +392,10 @@ def generate(
otherwise will be of shape [B, max_length]. Default: `None`.
tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative
prompts instead of string prompts. Default: `None`.
tokenized_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for
pre-tokenized prompts. Default `None`.
        tokenized_negative_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for
pre-tokenized negative prompts. Default `None`.
prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead
of string prompts. If both prompt and prompt_embeds
are passed, prompt_embeds will be used. Default: `None`.
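A hedged sketch of calling generate() with the new padding-mask arguments (the prompt text, image size, and guidance scale are illustrative, and `model` is assumed to be an already-constructed non-SDXL StableDiffusion instance):

# Tokenize the prompt the same way the dataset does, keeping both input_ids and attention_mask.
tokenized = model.tokenizer(['a photo of an astronaut riding a horse'],
                            padding='max_length',
                            max_length=model.tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt')

images = model.generate(tokenized_prompts=tokenized.input_ids,
                        tokenized_prompts_pad_mask=tokenized.attention_mask,
                        height=512,
                        width=512,
                        guidance_scale=7.5)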
@@ -423,22 +450,24 @@ def generate(

do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore

text_embeddings, pooled_text_embeddings = self._prepare_text_embeddings(prompt, tokenized_prompts,
prompt_embeds, num_images_per_prompt)
text_embeddings, pooled_text_embeddings, pad_attn_mask = self._prepare_text_embeddings(
prompt, tokenized_prompts, tokenized_prompts_pad_mask, prompt_embeds, num_images_per_prompt)
batch_size = len(text_embeddings) # len prompts * num_images_per_prompt
# classifier free guidance + negative prompts
# negative prompt is given in place of the unconditional input in classifier free guidance
pooled_embeddings = None
pooled_embeddings, encoder_attn_mask = pooled_text_embeddings, pad_attn_mask
if do_classifier_free_guidance:
if not negative_prompt and not tokenized_negative_prompts and not negative_prompt_embeds and zero_out_negative_prompt:
# Negative prompt is empty and we want to zero it out
unconditional_embeddings = torch.zeros_like(text_embeddings)
pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None
uncond_pad_attn_mask = torch.zeros_like(pad_attn_mask) if pad_attn_mask is not None else None
else:
if not negative_prompt:
negative_prompt = [''] * (batch_size // num_images_per_prompt) # type: ignore
unconditional_embeddings, pooled_unconditional_embeddings = self._prepare_text_embeddings(
negative_prompt, tokenized_negative_prompts, negative_prompt_embeds, num_images_per_prompt)
unconditional_embeddings, pooled_unconditional_embeddings, uncond_pad_attn_mask = self._prepare_text_embeddings(
negative_prompt, tokenized_negative_prompts, tokenized_negative_prompts_pad_mask,
negative_prompt_embeds, num_images_per_prompt)

# concat uncond + prompt
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings])
@@ -448,6 +477,9 @@ def generate(
if self.sdxl:
pooled_embeddings = pooled_text_embeddings

if pad_attn_mask is not None:
encoder_attn_mask = torch.cat([uncond_pad_attn_mask, pad_attn_mask]) # type: ignore

# prepare for diffusion generation process
latents = torch.randn(
(batch_size, self.unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
@@ -488,6 +520,7 @@ def generate(
pred = self.unet(latent_model_input,
t,
encoder_hidden_states=text_embeddings,
encoder_attention_mask=encoder_attn_mask,
added_cond_kwargs=added_cond_kwargs).sample

if do_classifier_free_guidance:
@@ -510,20 +543,28 @@ def generate(
image = (image / 2 + 0.5).clamp(0, 1)
return image.detach() # (batch*num_images_per_prompt, channel, h, w)

def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num_images_per_prompt):
def _prepare_text_embeddings(self, prompt, tokenized_prompts, tokenized_pad_mask, prompt_embeds,
num_images_per_prompt):
"""Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt."""
device = self.text_encoder.device
pooled_text_embeddings = None
if prompt_embeds is None:
max_length = None if self.sdxl else self.tokenizer.model_max_length
if tokenized_prompts is None:
tokenized_prompts = self.tokenizer(prompt,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors='pt').input_ids
tokenized_out = self.tokenizer(prompt,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors='pt')
tokenized_prompts = tokenized_out.input_ids
if self.mask_pad_tokens:
tokenized_pad_mask = tokenized_out.attention_mask
if self.sdxl:
tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1)
if self.mask_pad_tokens:
# For cross attention mask, take union of masks (want [B, 77])
tokenized_pad_mask = torch.logical_or(tokenized_pad_mask[0], tokenized_pad_mask[1]).to(
tokenized_pad_mask[0].dtype).to(device)
if self.sdxl:
text_embeddings, pooled_text_embeddings = self.text_encoder(
[tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)]) # type: ignore
@@ -539,10 +580,14 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # type: ignore
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

if tokenized_pad_mask is not None:
tokenized_pad_mask = tokenized_pad_mask.repeat(1, num_images_per_prompt, 1)
tokenized_pad_mask = tokenized_pad_mask.view(bs_embed * num_images_per_prompt, seq_len) # [B, 77]

if self.sdxl and pooled_text_embeddings is not None:
pooled_text_embeddings = pooled_text_embeddings.repeat(1, num_images_per_prompt)
pooled_text_embeddings = pooled_text_embeddings.view(bs_embed * num_images_per_prompt, -1)
return text_embeddings, pooled_text_embeddings
return text_embeddings, pooled_text_embeddings, tokenized_pad_mask


def _check_prompt_lenths(prompt, negative_prompt):
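A shape sketch of how the padding mask tracks the embeddings through classifier-free guidance in generate() above (batch size, sequence length, and embedding dimension are illustrative):

import torch

# Illustrative: 2 prompts, sequence length 77, embedding dimension 1024.
text_embeddings = torch.randn(2, 77, 1024)
pad_attn_mask = torch.ones(2, 77, dtype=torch.long)

# Zeroed unconditional branch, as in the zero_out_negative_prompt path.
unconditional_embeddings = torch.zeros_like(text_embeddings)
uncond_pad_attn_mask = torch.zeros_like(pad_attn_mask)

# The mask is concatenated in the same [unconditional, conditional] order as the embeddings,
# so each row of encoder_attn_mask lines up with the matching row of text_embeddings.
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings])  # shape [4, 77, 1024]
encoder_attn_mask = torch.cat([uncond_pad_attn_mask, pad_attn_mask])      # shape [4, 77]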
