diff --git a/models/modeling_mmada.py b/models/modeling_mmada.py index ddcf892..f12731e 100644 --- a/models/modeling_mmada.py +++ b/models/modeling_mmada.py @@ -587,8 +587,8 @@ def t2i_generate_decoding_stepwise( uncond_input_ids = torch.cat( [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) model_input = torch.cat([input_ids, uncond_input_ids]) - attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) - attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) logits = self(model_input, attention_bias=attention_bias).logits # print(f"logits.shape: {logits.shape}") cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)