Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions src/diffusers/models/transformers/transformer_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,15 @@ def forward(
return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)


def split_mod(mod: torch.Tensor, mod_param_sets: int):
if mod.ndim == 2:
mod = mod.unsqueeze(1)
mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
# Return tuple of 3-tuples of modulation params shift/scale/gate
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))



class Flux2SingleTransformerBlock(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -430,6 +439,8 @@ def forward(
split_hidden_states: bool = False,
text_seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb_mod_params = split_mod(temb_mod_params, 1)[0]

# If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
# concatenated
if encoder_hidden_states is not None:
Expand Down Expand Up @@ -504,6 +515,9 @@ def forward(
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
temb_mod_params_img = split_mod(temb_mod_params_img, 2)
temb_mod_params_txt = split_mod(temb_mod_params_txt, 2)


# Modulation parameters shape: [1, 1, self.dim]
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
Expand Down Expand Up @@ -621,11 +635,12 @@ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor,
mod = self.act_fn(temb)
mod = self.linear(mod)

if mod.ndim == 2:
mod = mod.unsqueeze(1)
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
# Return tuple of 3-tuples of modulation params shift/scale/gate
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
return mod
# if mod.ndim == 2:
# mod = mod.unsqueeze(1)
# mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
# # Return tuple of 3-tuples of modulation params shift/scale/gate
# return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))


class Flux2Transformer2DModel(
Expand Down Expand Up @@ -821,7 +836,7 @@ def forward(

double_stream_mod_img = self.double_stream_modulation_img(temb)
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb)[0]
single_stream_mod = self.single_stream_modulation(temb)

# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
hidden_states = self.x_embedder(hidden_states)
Expand Down