From 0881fe84184ad5970d93db39e76f4839579bd7e6 Mon Sep 17 00:00:00 2001
From: dxqb
Date: Tue, 2 Dec 2025 18:22:25 +0100
Subject: [PATCH] split tensors inside the transformer blocks to avoid
 checkpointing issues

---
 .../models/transformers/transformer_flux2.py | 27 +++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
index c10bf3ed4f7b..28a46884e941 100644
--- a/src/diffusers/models/transformers/transformer_flux2.py
+++ b/src/diffusers/models/transformers/transformer_flux2.py
@@ -390,6 +390,15 @@ def forward(
         return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
 
 
+def split_mod(mod: torch.Tensor, mod_param_sets: int):
+    if mod.ndim == 2:
+        mod = mod.unsqueeze(1)
+    mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
+    # Return tuple of 3-tuples of modulation params shift/scale/gate
+    return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
+
+
+
 class Flux2SingleTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -430,6 +439,8 @@ def forward(
         split_hidden_states: bool = False,
         text_seq_len: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        temb_mod_params = split_mod(temb_mod_params, 1)[0]
+
         # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
         # concatenated
         if encoder_hidden_states is not None:
@@ -504,6 +515,9 @@ def forward(
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
+        temb_mod_params_img = split_mod(temb_mod_params_img, 2)
+        temb_mod_params_txt = split_mod(temb_mod_params_txt, 2)
+
 
         # Modulation parameters shape: [1, 1, self.dim]
         (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
@@ -621,11 +635,12 @@ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor,
         mod = self.act_fn(temb)
         mod = self.linear(mod)
 
-        if mod.ndim == 2:
-            mod = mod.unsqueeze(1)
-        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
-        # Return tuple of 3-tuples of modulation params shift/scale/gate
-        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
+        return mod
+#        if mod.ndim == 2:
+#            mod = mod.unsqueeze(1)
+#        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
+#        # Return tuple of 3-tuples of modulation params shift/scale/gate
+#        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
 
 
 class Flux2Transformer2DModel(
@@ -821,7 +836,7 @@ def forward(
         double_stream_mod_img = self.double_stream_modulation_img(temb)
         double_stream_mod_txt = self.double_stream_modulation_txt(temb)
 
-        single_stream_mod = self.single_stream_modulation(temb)[0]
+        single_stream_mod = self.single_stream_modulation(temb)
 
         # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
         hidden_states = self.x_embedder(hidden_states)
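
Background on the checkpointing issue this patch works around, with a
standalone sketch: torch.utils.checkpoint, in its reentrant form, only
tracks tensors passed as direct positional arguments; tensors nested
inside tuples (like the tuple of shift/scale/gate 3-tuples the modulation
module used to return) are invisible to it, so gradients through the
modulation parameters can break when the blocks are checkpointed.
Returning the raw modulation tensor and calling split_mod inside each
block keeps it a plain tensor argument. The sketch below assumes
use_reentrant=True and uses made-up shapes and function names, not code
from this file:

import torch
from torch.utils.checkpoint import checkpoint

# Stand-ins for the modulation output and the hidden states (hypothetical shapes).
mod = torch.randn(1, 1, 6, requires_grad=True)
x = torch.randn(1, 4, 2)

def block_split_outside(x, mod_params):
    # Modulation tensors arrive nested in a tuple: checkpoint cannot see them.
    shift, scale, gate = mod_params
    return gate * (x * (1 + scale) + shift)

def block_split_inside(x, mod):
    # Plain tensor argument, split into shift/scale/gate inside the block.
    shift, scale, gate = torch.chunk(mod, 3, dim=-1)
    return gate * (x * (1 + scale) + shift)

# Tuple argument: checkpoint warns that none of its inputs require grad,
# and the output is detached from `mod`.
out = checkpoint(block_split_outside, x, torch.chunk(mod, 3, dim=-1), use_reentrant=True)
print(out.requires_grad)  # False: gradients to `mod` are lost

# Tensor argument: gradients flow back to `mod` as expected.
out = checkpoint(block_split_inside, x, mod, use_reentrant=True)
out.sum().backward()
print(mod.grad.shape)  # torch.Size([1, 1, 6])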