Support Expert loss for HiDream #11673

Open · wants to merge 1 commit into base: main
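This PR surfaces the Mixture-of-Experts auxiliary (load-balancing) loss from the HiDream image transformer so that training code can see it: the previously commented-out AddAuxiliaryLoss.apply call is re-enabled, MOEFeedForward and both transformer block variants now return the per-block auxiliary loss alongside their hidden states, and the model's forward collects the losses into a new HiDreamImageModelOutput (or appends them to the plain tuple when return_dict=False and return_auxiliary_loss=True).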
69 changes: 54 additions & 15 deletions src/diffusers/models/transformers/transformer_hidream_image.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -6,9 +7,8 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -17,6 +17,29 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@dataclass
+class HiDreamImageModelOutput(BaseOutput):
+    sample: torch.Tensor
+    double_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+    single_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert loss.numel() == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = loss.requires_grad
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+        return grad_output, grad_loss
+
+
 class HiDreamImageFeedForwardSwiGLU(nn.Module):
     def __init__(
         self,
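The AddAuxiliaryLoss function introduced above is the usual trick (seen, for example, in DeepSeek-MoE) for attaching a scalar loss to the autograd graph without altering the forward value: forward returns x unchanged, while backward hands the loss input a gradient of exactly one, so the alpha-weighted balance loss reaches the router parameters even though callers never add it to their objective. A minimal sketch of that behavior, assuming the class above is in scope (the tensors here are illustrative, not part of the diff):

import torch

x = torch.randn(4, 8, requires_grad=True)
router_logits = torch.randn(4, 2, requires_grad=True)
aux_loss = router_logits.softmax(-1).var()      # stand-in for the gate's balance loss

y = AddAuxiliaryLoss.apply(x, aux_loss)         # y is numerically identical to x
y.sum().backward()                              # backward() injects d/d(aux_loss) = 1

print(torch.equal(x.grad, torch.ones_like(x)))  # True: x's gradient passes through untouched
print(router_logits.grad is not None)           # True: the router still receives gradients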
@@ -332,7 +355,6 @@ def forward(self, hidden_states):
             else:
                 mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                 ce = mask_ce.float().mean(0)
-
                 Pi = scores_for_aux.mean(0)
                 fi = ce * self.n_routed_experts
                 aux_loss = (Pi * fi).sum() * self.alpha
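For reference, this block computes the standard switch-style load-balancing penalty: Pi is the mean router probability assigned to each expert, fi is n_routed_experts times the fraction of top-k assignments each expert receives, and aux_loss = alpha * sum(Pi * fi) equals alpha under perfectly uniform routing and grows as routing collapses onto a few experts. A quick numeric illustration (the values are made up):

import torch

N, alpha = 4, 0.01
# Balanced routing: every expert gets 1/N of the probability mass and of the tokens.
Pi = torch.full((N,), 1 / N)
fi = torch.full((N,), 1 / N) * N                # ce * n_routed_experts, as in the gate
print((Pi * fi).sum() * alpha)                  # 0.0100: the balanced value, alpha
# Collapsed routing: expert 0 receives nearly everything.
Pi = torch.tensor([0.97, 0.01, 0.01, 0.01])
fi = torch.tensor([1.00, 0.00, 0.00, 0.00]) * N
print((Pi * fi).sum() * alpha)                  # 0.0388: imbalance is penalized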
@@ -379,11 +401,11 @@ def forward(self, x):
                 y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.view(*orig_shape).to(dtype=wtype)
-            # y = AddAuxiliaryLoss.apply(y, aux_loss)
+            y = AddAuxiliaryLoss.apply(y, aux_loss)
         else:
             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         y = y + self.shared_experts(identity)
-        return y
+        return y, aux_loss
 
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
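Note the breaking change in this hunk: MOEFeedForward.forward now returns a (hidden_states, aux_loss) tuple rather than a bare tensor, and the callers below are updated to unpack it. Since the gate (unchanged here) appears to compute the balance loss only in training mode, downstream code should tolerate aux_loss being None at inference. A small stand-in illustrating the new contract (TupleMoEStub is hypothetical, not from the diff):

import torch

class TupleMoEStub(torch.nn.Module):
    # Mirrors only the return convention of the patched MOEFeedForward.
    def forward(self, x):
        aux_loss = x.float().pow(2).mean() if self.training else None
        return x, aux_loss

moe = TupleMoEStub()
y, aux = moe.train()(torch.randn(2, 16))   # training: aux is a scalar tensor
y, aux = moe.eval()(torch.randn(2, 16))    # inference: aux is None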
@@ -481,9 +503,10 @@ def forward(
         # 2. Feed-forward
         norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i = gate_mlp_i * ff_output_i
         hidden_states = ff_output_i + hidden_states
-        return hidden_states
+        return hidden_states, aux_loss
 
 
 @maybe_allow_in_graph
@@ -573,11 +596,12 @@ def forward(
         norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
 
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states)
+        ff_output_i = gate_mlp_i * ff_output_i
         ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
         hidden_states = ff_output_i + hidden_states
         encoder_hidden_states = ff_output_t + encoder_hidden_states
-        return hidden_states, encoder_hidden_states
+        return hidden_states, encoder_hidden_states, aux_loss
 
 
 class HiDreamBlock(nn.Module):
@@ -785,6 +809,7 @@ def forward(
         hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        return_auxiliary_loss: bool = False,
         **kwargs,
     ):
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
@@ -866,15 +891,19 @@ def forward(
 
         # 2. Blocks
         block_id = 0
+        double_blocks_aux_losses = []
+        single_blocks_aux_losses = []
+
         initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
         initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
 
         for bid, block in enumerate(self.double_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             cur_encoder_hidden_states = torch.cat(
                 [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                hidden_states, initial_encoder_hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
                     cur_encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                 )
             else:
-                hidden_states, initial_encoder_hidden_states = block(
+                hidden_states, initial_encoder_hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=cur_encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
+            double_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         image_tokens_seq_len = hidden_states.shape[1]
@@ -908,7 +938,7 @@ def forward(
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
                     None,
                     temb,
                     image_rotary_emb,
                 )
             else:
-                hidden_states = block(
+                hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=None,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             hidden_states = hidden_states[:, :hidden_states_seq_len]
+            single_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
@@ … @@
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return (output,)
-        return Transformer2DModelOutput(sample=output)
+            return_values = (output,)
+            if return_auxiliary_loss:
+                return_values += (double_blocks_aux_losses, single_blocks_aux_losses)
+            return return_values
+
+        return HiDreamImageModelOutput(
+            sample=output,
+            double_blocks_auxiliary_loss=double_blocks_aux_losses,
+            single_blocks_auxiliary_loss=single_blocks_aux_losses,
+        )
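Taken together, a training step can read the per-block losses off the new output class. A hedged end-to-end sketch: transformer, model_inputs, target, and optimizer stand for a real HiDream training setup and are not defined by this diff. Note that AddAuxiliaryLoss already injects the balance-loss gradient through sample during backward, so the returned values are mainly useful for logging or for deliberately re-weighting the penalty; adding them to the objective unchanged would count them twice.

import torch
import torch.nn.functional as F

out = transformer(**model_inputs, return_dict=True)   # HiDreamImageModelOutput
diffusion_loss = F.mse_loss(out.sample, target)

aux = [l for l in list(out.double_blocks_auxiliary_loss) + list(out.single_blocks_auxiliary_loss) if l is not None]
aux_total = torch.stack(aux).sum() if aux else torch.zeros(())
print(f"diffusion: {diffusion_loss.item():.4f}  aux (already backpropagated): {aux_total.item():.6f}")

diffusion_loss.backward()
optimizer.step()

# With return_dict=False the same values come back as extra tuple entries:
# output, double_aux, single_aux = transformer(**model_inputs, return_dict=False, return_auxiliary_loss=True)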