
Commit 5decf2a

Allow masking padding tokens in cross attention layers (#94)
* add padding attn mask to training
* remove squeeze
* torch tensorify
* handle sdxl
* encoder attn mask
* retry
* pad masking in generate() and pyright
* toggle pad masking with flag, add arg for token masks in generate()
1 parent 3122b81 commit 5decf2a
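
The padding mask threaded through these files is the tokenizer's attention_mask: 1 for real tokens, 0 for pad tokens. A minimal sketch of what that mask looks like, assuming a CLIP tokenizer like the ones these pipelines use (the checkpoint name below is illustrative, not taken from this repo):

from transformers import CLIPTokenizer

# Illustrative checkpoint; the repo's pipelines configure their own tokenizers.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
out = tokenizer('a photo of a cat',
                padding='max_length',
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors='pt')
print(out.input_ids.shape)    # torch.Size([1, 77]): caption tokens followed by pad tokens
print(out.attention_mask[0])  # 1 for real tokens, 0 for padding positions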

File tree

diffusion/datasets/image_caption.py
diffusion/models/models.py
diffusion/models/stable_diffusion.py

3 files changed: +76 -20 lines


diffusion/datasets/image_caption.py

Lines changed: 12 additions & 7 deletions
@@ -151,17 +151,22 @@ def __getitem__(self, index):
             out['drop_caption_mask'] = 1.0
 
         max_length = None if self.sdxl else self.tokenizer.model_max_length  # type: ignore
-        tokenized_caption = self.tokenizer(caption,
-                                           padding='max_length',
-                                           max_length=max_length,
-                                           truncation=True,
-                                           return_tensors='pt')['input_ids']
+        tokenizer_out = self.tokenizer(caption,
+                                       padding='max_length',
+                                       max_length=max_length,
+                                       truncation=True,
+                                       return_tensors='pt')
         if self.sdxl:
-            tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
+            tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenizer_out.input_ids]
             tokenized_caption = torch.stack(tokenized_caption)
+            # Take the union of both tokenizers' padding masks
+            attention_masks = tokenizer_out.attention_mask
+            attention_mask = torch.logical_or(attention_masks[0], attention_masks[1]).to(attention_masks[0].dtype)
         else:
-            tokenized_caption = tokenized_caption.squeeze()
+            tokenized_caption = tokenizer_out.input_ids.squeeze()
+            attention_mask = tokenizer_out.attention_mask
         out['captions'] = tokenized_caption
+        out['attention_mask'] = attention_mask
         return out
 
 

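Because SDXL tokenizes each caption with two tokenizers, the dataset builds a single cross-attention mask as the elementwise union of the two padding masks (the torch.logical_or above). A small sketch of that rule, with toy masks standing in for the real [1, 77] ones:

import torch

# Toy masks in place of the two tokenizers' [1, 77] attention masks.
mask_a = torch.tensor([[1, 1, 1, 0, 0]])
mask_b = torch.tensor([[1, 1, 0, 0, 0]])

# A position stays visible to cross attention if either tokenizer kept it.
union = torch.logical_or(mask_a, mask_b).to(mask_a.dtype)
print(union)  # tensor([[1, 1, 1, 0, 0]])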
diffusion/models/models.py

Lines changed: 6 additions & 0 deletions
@@ -41,6 +41,7 @@ def stable_diffusion_2(
     loss_bins: Optional[List] = None,
     precomputed_latents: bool = False,
     encode_latents_in_fp16: bool = True,
+    mask_pad_tokens: bool = False,
     fsdp: bool = True,
     clip_qkv: Optional[float] = None,
 ):
@@ -67,6 +68,7 @@ def stable_diffusion_2(
         offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
             be used. Default `None`.
         encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
+        mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False.
         fsdp (bool): Whether to use FSDP. Defaults to True.
         clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None.
     """
@@ -123,6 +125,7 @@ def stable_diffusion_2(
         loss_bins=loss_bins,
         precomputed_latents=precomputed_latents,
         encode_latents_in_fp16=encode_latents_in_fp16,
+        mask_pad_tokens=mask_pad_tokens,
         fsdp=fsdp,
     )
     if torch.cuda.is_available():
@@ -156,6 +159,7 @@ def stable_diffusion_xl(
     loss_bins: Optional[List] = None,
     precomputed_latents: bool = False,
     encode_latents_in_fp16: bool = True,
+    mask_pad_tokens: bool = False,
     fsdp: bool = True,
     clip_qkv: Optional[float] = 6.0,
 ):
@@ -188,6 +192,7 @@ def stable_diffusion_xl(
             [(0, 1)].
         precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
         encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
+        mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False.
         fsdp (bool): Whether to use FSDP. Defaults to True.
         clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to 6.0. Improves stability
             of training.
@@ -259,6 +264,7 @@ def stable_diffusion_xl(
         loss_bins=loss_bins,
         precomputed_latents=precomputed_latents,
         encode_latents_in_fp16=encode_latents_in_fp16,
+        mask_pad_tokens=mask_pad_tokens,
         fsdp=fsdp,
         sdxl=True,
     )

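Both factory functions simply thread the new flag through to the StableDiffusion constructor. A hedged usage sketch, assuming stable_diffusion_2 is importable from the module shown above (constructing the model loads pretrained weights):

from diffusion.models.models import stable_diffusion_2

# mask_pad_tokens defaults to False, so existing configs keep their current behavior.
model = stable_diffusion_2(mask_pad_tokens=True)
assert model.mask_pad_tokens  # checked in forward()/eval_forward() before the mask is used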
diffusion/models/stable_diffusion.py

Lines changed: 58 additions & 13 deletions
@@ -64,6 +64,8 @@ class StableDiffusion(ComposerModel):
             Default: `False`.
         encode_latents_in_fp16 (bool): whether to encode latents in fp16.
             Default: `False`.
+        mask_pad_tokens (bool): whether to mask pad tokens in cross attention.
+            Default: `False`.
         sdxl (bool): Whether or not we're training SDXL. Default: `False`.
     """
 
@@ -88,6 +90,7 @@ def __init__(self,
                  text_latents_key: str = 'caption_latents',
                  precomputed_latents: bool = False,
                  encode_latents_in_fp16: bool = False,
+                 mask_pad_tokens: bool = False,
                  fsdp: bool = False,
                  sdxl: bool = False):
         super().__init__()
@@ -103,6 +106,7 @@ def __init__(self,
         self.image_key = image_key
         self.image_latents_key = image_latents_key
         self.precomputed_latents = precomputed_latents
+        self.mask_pad_tokens = mask_pad_tokens
         self.sdxl = sdxl
         if self.sdxl:
             self.latent_scale = 0.13025
@@ -152,6 +156,7 @@ def __init__(self,
         self.text_key = text_key
         self.text_latents_key = text_latents_key
         self.encode_latents_in_fp16 = encode_latents_in_fp16
+        self.mask_pad_tokens = mask_pad_tokens
         # freeze text_encoder during diffusion training
         self.text_encoder.requires_grad_(False)
         self.vae.requires_grad_(False)
@@ -206,6 +211,12 @@ def forward(self, batch):
         if pooled_conditioning is not None:
             pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1)
 
+        # Attention mask if needed
+        if self.mask_pad_tokens and 'attention_mask' in batch.keys():
+            encoder_attention_mask = batch['attention_mask']
+        else:
+            encoder_attention_mask = None
+
         # Sample the diffusion timesteps
         timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device)
         # Add noise to the inputs (forward diffusion)
@@ -234,7 +245,10 @@ def forward(self, batch):
         added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids}
 
         # Forward through the model
-        return self.unet(noised_latents, timesteps, conditioning,
+        return self.unet(noised_latents,
+                         timesteps,
+                         conditioning,
+                         encoder_attention_mask=encoder_attention_mask,
                          added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps
 
     def loss(self, outputs, batch):
@@ -252,6 +266,12 @@ def eval_forward(self, batch, outputs=None):
         prompts = batch[self.text_key]
         height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]
 
+        # Attention mask if needed
+        if self.mask_pad_tokens and 'attention_mask' in batch.keys():
+            encoder_attention_mask = batch['attention_mask']
+        else:
+            encoder_attention_mask = None
+
         # If SDXL, add eval-time micro-conditioning to batch
         if self.sdxl:
             device = self.unet.device
@@ -266,6 +286,7 @@ def eval_forward(self, batch, outputs=None):
         generated_images = {}
         for guidance_scale in self.val_guidance_scales:
             gen_images = self.generate(tokenized_prompts=prompts,
+                                       tokenized_prompts_pad_mask=encoder_attention_mask,
                                        height=height,
                                        width=width,
                                        guidance_scale=guidance_scale,
@@ -339,6 +360,8 @@ def generate(
         negative_prompt: Optional[list] = None,
         tokenized_prompts: Optional[torch.LongTensor] = None,
         tokenized_negative_prompts: Optional[torch.LongTensor] = None,
+        tokenized_prompts_pad_mask: Optional[torch.LongTensor] = None,
+        tokenized_negative_prompts_pad_mask: Optional[torch.LongTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
@@ -369,6 +392,10 @@ def generate(
                 otherwise will be of shape [B, max_length]. Default: `None`.
             tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative
                 prompts instead of string prompts. Default: `None`.
+            tokenized_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for
+                pre-tokenized prompts. Default `None`.
+            tokenized_negative_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for
+                pre-tokenized negative prompts. Default `None`.
             prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead
                 of string prompts. If both prompt and prompt_embeds
                 are passed, prompt_embeds will be used. Default: `None`.
@@ -423,22 +450,24 @@ def generate(
 
         do_classifier_free_guidance = guidance_scale > 1.0  # type: ignore
 
-        text_embeddings, pooled_text_embeddings = self._prepare_text_embeddings(prompt, tokenized_prompts,
-                                                                                prompt_embeds, num_images_per_prompt)
+        text_embeddings, pooled_text_embeddings, pad_attn_mask = self._prepare_text_embeddings(
+            prompt, tokenized_prompts, tokenized_prompts_pad_mask, prompt_embeds, num_images_per_prompt)
         batch_size = len(text_embeddings)  # len prompts * num_images_per_prompt
         # classifier free guidance + negative prompts
         # negative prompt is given in place of the unconditional input in classifier free guidance
-        pooled_embeddings = None
+        pooled_embeddings, encoder_attn_mask = pooled_text_embeddings, pad_attn_mask
         if do_classifier_free_guidance:
             if not negative_prompt and not tokenized_negative_prompts and not negative_prompt_embeds and zero_out_negative_prompt:
                 # Negative prompt is empty and we want to zero it out
                 unconditional_embeddings = torch.zeros_like(text_embeddings)
                 pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None
+                uncond_pad_attn_mask = torch.zeros_like(pad_attn_mask) if pad_attn_mask is not None else None
             else:
                 if not negative_prompt:
                     negative_prompt = [''] * (batch_size // num_images_per_prompt)  # type: ignore
-                unconditional_embeddings, pooled_unconditional_embeddings = self._prepare_text_embeddings(
-                    negative_prompt, tokenized_negative_prompts, negative_prompt_embeds, num_images_per_prompt)
+                unconditional_embeddings, pooled_unconditional_embeddings, uncond_pad_attn_mask = self._prepare_text_embeddings(
+                    negative_prompt, tokenized_negative_prompts, tokenized_negative_prompts_pad_mask,
+                    negative_prompt_embeds, num_images_per_prompt)
 
             # concat uncond + prompt
             text_embeddings = torch.cat([unconditional_embeddings, text_embeddings])
@@ -448,6 +477,9 @@ def generate(
             if self.sdxl:
                 pooled_embeddings = pooled_text_embeddings
 
+            if pad_attn_mask is not None:
+                encoder_attn_mask = torch.cat([uncond_pad_attn_mask, pad_attn_mask])  # type: ignore
+
         # prepare for diffusion generation process
         latents = torch.randn(
             (batch_size, self.unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
@@ -488,6 +520,7 @@ def generate(
             pred = self.unet(latent_model_input,
                              t,
                              encoder_hidden_states=text_embeddings,
+                             encoder_attention_mask=encoder_attn_mask,
                              added_cond_kwargs=added_cond_kwargs).sample
 
             if do_classifier_free_guidance:
@@ -510,20 +543,28 @@ def generate(
         image = (image / 2 + 0.5).clamp(0, 1)
         return image.detach()  # (batch*num_images_per_prompt, channel, h, w)
 
-    def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num_images_per_prompt):
+    def _prepare_text_embeddings(self, prompt, tokenized_prompts, tokenized_pad_mask, prompt_embeds,
+                                 num_images_per_prompt):
         """Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt."""
         device = self.text_encoder.device
         pooled_text_embeddings = None
         if prompt_embeds is None:
             max_length = None if self.sdxl else self.tokenizer.model_max_length
             if tokenized_prompts is None:
-                tokenized_prompts = self.tokenizer(prompt,
-                                                   padding='max_length',
-                                                   max_length=max_length,
-                                                   truncation=True,
-                                                   return_tensors='pt').input_ids
+                tokenized_out = self.tokenizer(prompt,
+                                               padding='max_length',
+                                               max_length=max_length,
+                                               truncation=True,
+                                               return_tensors='pt')
+                tokenized_prompts = tokenized_out.input_ids
+                if self.mask_pad_tokens:
+                    tokenized_pad_mask = tokenized_out.attention_mask
             if self.sdxl:
                 tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1)
+                if self.mask_pad_tokens:
+                    # For cross attention mask, take union of masks (want [B, 77])
+                    tokenized_pad_mask = torch.logical_or(tokenized_pad_mask[0], tokenized_pad_mask[1]).to(
+                        tokenized_pad_mask[0].dtype).to(device)
             if self.sdxl:
                 text_embeddings, pooled_text_embeddings = self.text_encoder(
                     [tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)])  # type: ignore
@@ -539,10 +580,14 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # type: ignore
         text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
+        if tokenized_pad_mask is not None:
+            tokenized_pad_mask = tokenized_pad_mask.repeat(1, num_images_per_prompt, 1)
+            tokenized_pad_mask = tokenized_pad_mask.view(bs_embed * num_images_per_prompt, seq_len)  # [B, 77]
+
         if self.sdxl and pooled_text_embeddings is not None:
             pooled_text_embeddings = pooled_text_embeddings.repeat(1, num_images_per_prompt)
             pooled_text_embeddings = pooled_text_embeddings.view(bs_embed * num_images_per_prompt, -1)
-        return text_embeddings, pooled_text_embeddings
+        return text_embeddings, pooled_text_embeddings, tokenized_pad_mask
 
 
 def _check_prompt_lenths(prompt, negative_prompt):

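With these changes generate() accepts a padding mask alongside pre-tokenized prompts; when prompts are passed as strings and mask_pad_tokens is set, _prepare_text_embeddings pulls the mask from the tokenizer itself. A sketch of the pre-tokenized path for a non-SDXL model (single tokenizer, so no mask union is needed); `model` is assumed to be a StableDiffusion instance built with mask_pad_tokens=True:

# Tokenize outside the model, mirroring the dataset code above.
tokenized = model.tokenizer('a corgi wearing sunglasses',
                            padding='max_length',
                            max_length=model.tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt')
# Pass both the token ids and their padding mask so pad positions are ignored in cross attention.
images = model.generate(tokenized_prompts=tokenized.input_ids,
                        tokenized_prompts_pad_mask=tokenized.attention_mask,
                        height=512,
                        width=512,
                        guidance_scale=7.5)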