@@ -64,6 +64,8 @@ class StableDiffusion(ComposerModel):
             Default: `False`.
         encode_latents_in_fp16 (bool): whether to encode latents in fp16.
             Default: `False`.
+        mask_pad_tokens (bool): whether to mask pad tokens in cross attention.
+            Default: `False`.
         sdxl (bool): Whether or not we're training SDXL. Default: `False`.
     """
 
@@ -88,6 +90,7 @@ def __init__(self,
                  text_latents_key: str = 'caption_latents',
                  precomputed_latents: bool = False,
                  encode_latents_in_fp16: bool = False,
+                 mask_pad_tokens: bool = False,
                  fsdp: bool = False,
                  sdxl: bool = False):
         super().__init__()
@@ -103,6 +106,7 @@ def __init__(self,
         self.image_key = image_key
         self.image_latents_key = image_latents_key
         self.precomputed_latents = precomputed_latents
+        self.mask_pad_tokens = mask_pad_tokens
         self.sdxl = sdxl
         if self.sdxl:
             self.latent_scale = 0.13025
@@ -152,6 +156,7 @@ def __init__(self,
         self.text_key = text_key
         self.text_latents_key = text_latents_key
         self.encode_latents_in_fp16 = encode_latents_in_fp16
+        self.mask_pad_tokens = mask_pad_tokens
         # freeze text_encoder during diffusion training
         self.text_encoder.requires_grad_(False)
         self.vae.requires_grad_(False)
@@ -206,6 +211,12 @@ def forward(self, batch):
         if pooled_conditioning is not None:
             pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1)
 
+        # Attention mask if needed
+        if self.mask_pad_tokens and 'attention_mask' in batch.keys():
+            encoder_attention_mask = batch['attention_mask']
+        else:
+            encoder_attention_mask = None
+
         # Sample the diffusion timesteps
         timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device)
         # Add noise to the inputs (forward diffusion)
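
Note: the new `mask_pad_tokens` path in `forward` only uses a mask the dataloader actually provides under the `'attention_mask'` key. A minimal sketch of producing such a batch with a Hugging Face CLIP tokenizer (the checkpoint name and the `'captions'`/`'image'` keys are assumptions for illustration, not the repo's dataloader):

```python
# Hypothetical sketch: a batch whose 'attention_mask' marks real tokens (1) vs.
# padding (0), so forward() can pass it to the UNet when mask_pad_tokens=True.
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
tokenized = tokenizer(['a photo of a corgi', 'a dog'],
                      padding='max_length',
                      max_length=tokenizer.model_max_length,
                      truncation=True,
                      return_tensors='pt')
batch = {
    'captions': tokenized.input_ids,             # assumed text key
    'attention_mask': tokenized.attention_mask,  # [B, 77]; 1 = real token, 0 = pad
    'image': torch.randn(2, 3, 512, 512),        # assumed image key
}
```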
@@ -234,7 +245,10 @@ def forward(self, batch):
             added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids}
 
         # Forward through the model
-        return self.unet(noised_latents, timesteps, conditioning,
+        return self.unet(noised_latents,
+                         timesteps,
+                         conditioning,
+                         encoder_attention_mask=encoder_attention_mask,
                          added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps
 
     def loss(self, outputs, batch):
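
Note: `encoder_attention_mask` is an existing keyword of `UNet2DConditionModel.forward` in recent diffusers releases; padded text positions then contribute essentially nothing to cross attention. A self-contained illustration of that effect (not the diffusers internals verbatim):

```python
# Illustration only: a 0/1 pad mask acts like a large negative additive bias on
# the cross-attention logits, so padded text positions get ~zero attention weight.
import torch

def pad_mask_to_bias(pad_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # pad_mask: [B, seq_len], 1 for real tokens, 0 for padding
    bias = (1.0 - pad_mask.to(dtype)) * torch.finfo(dtype).min
    return bias[:, None, :]  # broadcast over query positions -> [B, 1, seq_len]

scores = torch.zeros(1, 4, 5)           # [B, num_queries, seq_len] attention logits
mask = torch.tensor([[1, 1, 1, 0, 0]])  # last two text positions are padding
weights = torch.softmax(scores + pad_mask_to_bias(mask), dim=-1)
print(weights[0, 0])  # ~[0.33, 0.33, 0.33, 0.00, 0.00]
```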
@@ -252,6 +266,12 @@ def eval_forward(self, batch, outputs=None):
         prompts = batch[self.text_key]
         height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]
 
+        # Attention mask if needed
+        if self.mask_pad_tokens and 'attention_mask' in batch.keys():
+            encoder_attention_mask = batch['attention_mask']
+        else:
+            encoder_attention_mask = None
+
         # If SDXL, add eval-time micro-conditioning to batch
         if self.sdxl:
             device = self.unet.device
@@ -266,6 +286,7 @@ def eval_forward(self, batch, outputs=None):
         generated_images = {}
         for guidance_scale in self.val_guidance_scales:
             gen_images = self.generate(tokenized_prompts=prompts,
+                                       tokenized_prompts_pad_mask=encoder_attention_mask,
                                        height=height,
                                        width=width,
                                        guidance_scale=guidance_scale,
@@ -339,6 +360,8 @@ def generate(
         negative_prompt: Optional[list] = None,
         tokenized_prompts: Optional[torch.LongTensor] = None,
         tokenized_negative_prompts: Optional[torch.LongTensor] = None,
+        tokenized_prompts_pad_mask: Optional[torch.LongTensor] = None,
+        tokenized_negative_prompts_pad_mask: Optional[torch.LongTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
@@ -369,6 +392,10 @@ def generate(
                 otherwise will be of shape [B, max_length]. Default: `None`.
             tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative
                 prompts instead of string prompts. Default: `None`.
+            tokenized_prompts_pad_mask (torch.LongTensor): Optionally pass a padding mask for
+                pre-tokenized prompts. Default: `None`.
+            tokenized_negative_prompts_pad_mask (torch.LongTensor): Optionally pass a padding mask for
+                pre-tokenized negative prompts. Default: `None`.
             prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead
                 of string prompts. If both prompt and prompt_embeds
                 are passed, prompt_embeds will be used. Default: `None`.
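
Note: a hypothetical usage sketch for the two new arguments (assume `model` is a constructed `StableDiffusion` with `mask_pad_tokens=True` and `tokenizer` is its text tokenizer; everything else mirrors the existing `generate` signature):

```python
# Hypothetical call: pass the tokenizer's attention_mask alongside the input_ids
# so padded prompt positions are masked out of cross attention during sampling.
tokenized = tokenizer(['a photo of a corgi'],
                      padding='max_length',
                      max_length=tokenizer.model_max_length,
                      truncation=True,
                      return_tensors='pt')
images = model.generate(tokenized_prompts=tokenized.input_ids,
                        tokenized_prompts_pad_mask=tokenized.attention_mask,
                        height=512,
                        width=512,
                        guidance_scale=7.5)
```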
@@ -423,22 +450,24 @@ def generate(
 
         do_classifier_free_guidance = guidance_scale > 1.0  # type: ignore
 
-        text_embeddings, pooled_text_embeddings = self._prepare_text_embeddings(prompt, tokenized_prompts,
-                                                                                 prompt_embeds, num_images_per_prompt)
+        text_embeddings, pooled_text_embeddings, pad_attn_mask = self._prepare_text_embeddings(
+            prompt, tokenized_prompts, tokenized_prompts_pad_mask, prompt_embeds, num_images_per_prompt)
         batch_size = len(text_embeddings)  # len prompts * num_images_per_prompt
         # classifier free guidance + negative prompts
         # negative prompt is given in place of the unconditional input in classifier free guidance
-        pooled_embeddings = None
+        pooled_embeddings, encoder_attn_mask = pooled_text_embeddings, pad_attn_mask
         if do_classifier_free_guidance:
             if not negative_prompt and not tokenized_negative_prompts and not negative_prompt_embeds and zero_out_negative_prompt:
                 # Negative prompt is empty and we want to zero it out
                 unconditional_embeddings = torch.zeros_like(text_embeddings)
                 pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None
+                uncond_pad_attn_mask = torch.zeros_like(pad_attn_mask) if pad_attn_mask is not None else None
             else:
                 if not negative_prompt:
                     negative_prompt = [''] * (batch_size // num_images_per_prompt)  # type: ignore
-                unconditional_embeddings, pooled_unconditional_embeddings = self._prepare_text_embeddings(
-                    negative_prompt, tokenized_negative_prompts, negative_prompt_embeds, num_images_per_prompt)
+                unconditional_embeddings, pooled_unconditional_embeddings, uncond_pad_attn_mask = self._prepare_text_embeddings(
+                    negative_prompt, tokenized_negative_prompts, tokenized_negative_prompts_pad_mask,
+                    negative_prompt_embeds, num_images_per_prompt)
 
             # concat uncond + prompt
             text_embeddings = torch.cat([unconditional_embeddings, text_embeddings])
@@ -448,6 +477,9 @@ def generate(
             if self.sdxl:
                 pooled_embeddings = pooled_text_embeddings
 
+        if pad_attn_mask is not None:
+            encoder_attn_mask = torch.cat([uncond_pad_attn_mask, pad_attn_mask])  # type: ignore
+
         # prepare for diffusion generation process
         latents = torch.randn(
             (batch_size, self.unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
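
Note: a toy shape check of the concatenation above (sizes assumed). The mask batch is doubled in the same [unconditional, conditional] order as `text_embeddings`, so each half of the expanded batch lines up with its own prompts:

```python
# Toy illustration: doubling the pad mask for classifier-free guidance.
import torch

pad_attn_mask = torch.ones(2, 77, dtype=torch.long)          # conditional prompts
uncond_pad_attn_mask = torch.zeros(2, 77, dtype=torch.long)  # zeroed-out negative prompts
encoder_attn_mask = torch.cat([uncond_pad_attn_mask, pad_attn_mask])
print(encoder_attn_mask.shape)  # torch.Size([4, 77])
```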
@@ -488,6 +520,7 @@ def generate(
             pred = self.unet(latent_model_input,
                              t,
                              encoder_hidden_states=text_embeddings,
+                             encoder_attention_mask=encoder_attn_mask,
                              added_cond_kwargs=added_cond_kwargs).sample
 
             if do_classifier_free_guidance:
@@ -510,20 +543,28 @@ def generate(
         image = (image / 2 + 0.5).clamp(0, 1)
         return image.detach()  # (batch*num_images_per_prompt, channel, h, w)
 
-    def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num_images_per_prompt):
+    def _prepare_text_embeddings(self, prompt, tokenized_prompts, tokenized_pad_mask, prompt_embeds,
+                                 num_images_per_prompt):
         """Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt."""
         device = self.text_encoder.device
         pooled_text_embeddings = None
         if prompt_embeds is None:
             max_length = None if self.sdxl else self.tokenizer.model_max_length
             if tokenized_prompts is None:
-                tokenized_prompts = self.tokenizer(prompt,
-                                                   padding='max_length',
-                                                   max_length=max_length,
-                                                   truncation=True,
-                                                   return_tensors='pt').input_ids
+                tokenized_out = self.tokenizer(prompt,
+                                               padding='max_length',
+                                               max_length=max_length,
+                                               truncation=True,
+                                               return_tensors='pt')
+                tokenized_prompts = tokenized_out.input_ids
+                if self.mask_pad_tokens:
+                    tokenized_pad_mask = tokenized_out.attention_mask
                 if self.sdxl:
                     tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1)
+                    if self.mask_pad_tokens:
+                        # For cross attention mask, take union of masks (want [B, 77])
+                        tokenized_pad_mask = torch.logical_or(tokenized_pad_mask[0], tokenized_pad_mask[1]).to(
+                            tokenized_pad_mask[0].dtype).to(device)
             if self.sdxl:
                 text_embeddings, pooled_text_embeddings = self.text_encoder(
                     [tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)])  # type: ignore
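
Note on the SDXL branch above: each prompt is tokenized by two tokenizers, and `torch.logical_or` keeps a position unmasked if either tokenizer marks it as a real token, collapsing the pair into a single `[B, 77]` cross-attention mask. A toy illustration with invented masks:

```python
# Toy illustration of the mask union used for SDXL's two tokenizers.
import torch

mask_a = torch.tensor([[1, 1, 1, 0, 0]])  # mask from tokenizer 1
mask_b = torch.tensor([[1, 1, 0, 0, 0]])  # mask from tokenizer 2
union = torch.logical_or(mask_a, mask_b).to(mask_a.dtype)
print(union)  # tensor([[1, 1, 1, 0, 0]])
```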
@@ -539,10 +580,14 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # type: ignore
         text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
+        if tokenized_pad_mask is not None:
+            tokenized_pad_mask = tokenized_pad_mask.repeat(1, num_images_per_prompt, 1)
+            tokenized_pad_mask = tokenized_pad_mask.view(bs_embed * num_images_per_prompt, seq_len)  # [B, 77]
+
         if self.sdxl and pooled_text_embeddings is not None:
             pooled_text_embeddings = pooled_text_embeddings.repeat(1, num_images_per_prompt)
             pooled_text_embeddings = pooled_text_embeddings.view(bs_embed * num_images_per_prompt, -1)
-        return text_embeddings, pooled_text_embeddings
+        return text_embeddings, pooled_text_embeddings, tokenized_pad_mask
 
 
 def _check_prompt_lenths(prompt, negative_prompt):