[Hi Dream] follow-up #11296

Merged: 18 commits, Apr 17, 2025
161 changes: 100 additions & 61 deletions src/diffusers/models/transformers/transformer_hidream_image.py
@@ -8,7 +8,7 @@
 from ...loaders import PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,108 @@ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_train
         x = torch.cat(x_arr, dim=0)
         return x

-    def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.config.patch_size * self.config.patch_size
-        if isinstance(x, torch.Tensor):
-            B, C = x.shape[0], x.shape[1]
-            device = x.device
-            dtype = x.dtype
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype

+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)

+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
         else:
-            B, C = len(x), x[0].shape[0]
-            device = x[0].device
-            dtype = x[0].dtype
-        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
+            hidden_states_masks = None

+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)

+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)

+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )

-        if img_sizes is not None:
-            for i, img_size in enumerate(img_sizes):
-                x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            B, C, S, _ = x.shape
-            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
-        elif isinstance(x, torch.Tensor):
-            B, C, Hp1, Wp2 = x.shape
-            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
-            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
-            x = x.permute(0, 2, 4, 3, 5, 1)
-            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
-            img_sizes = [[pH, pW]] * B
-            x_masks = None
         else:
-            raise NotImplementedError
-        return x, x_masks, img_sizes
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )

+        return hidden_states, hidden_states_masks, img_sizes, img_ids

     def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
-        img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        **kwargs,
     ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)

+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]

+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.34.0", deprecation_message)

+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )

         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
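For reference, a minimal calling sketch of the migration implied by the signature change above: the single `encoder_hidden_states` argument is kept alive through `**kwargs` with a deprecation warning and split internally into index 0 (T5) and index 1 (Llama-3). The names `transformer`, `latents`, `timestep`, `t5_embeds`, `llama3_embeds`, and `pooled`, and their shapes, are illustrative assumptions, not taken from this PR.

# Hedged sketch; `transformer` is assumed to be a loaded HiDreamImageTransformer2DModel
# and the embedding tensors are assumed to have the shapes the pipeline normally produces.

# Deprecated calling style: one sequence, unpacked internally as
# encoder_hidden_states[0] -> T5 embeds, encoder_hidden_states[1] -> Llama-3 hidden states.
sample = transformer(
    hidden_states=latents,
    timesteps=timestep,
    encoder_hidden_states=(t5_embeds, llama3_embeds),  # triggers the deprecation path above
    pooled_embeds=pooled,
    return_dict=False,
)[0]

# New calling style after this PR: pass the two text-encoder streams explicitly.
sample = transformer(
    hidden_states=latents,
    timesteps=timestep,
    encoder_hidden_states_t5=t5_embeds,
    encoder_hidden_states_llama3=llama3_embeds,
    pooled_embeds=pooled,
    return_dict=False,
)[0]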
@@ -745,42 +807,19 @@ def forward(
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype

-        if hidden_states.shape[-2] != hidden_states.shape[-1]:
-            B, C, H, W = hidden_states.shape
-            patch_size = self.config.patch_size
-            pH, pW = H // patch_size, W // patch_size
-            out = torch.zeros(
-                (B, C, self.max_seq, patch_size * patch_size),
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-            )
-            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
-            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
-            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
-            out[:, :, 0 : pH * pW] = hidden_states
-            hidden_states = out
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)

+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)

         # 0. time
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder

-        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if hidden_states_masks is None:
-            pH, pW = img_sizes[0]
-            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = (
-                img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
-                .unsqueeze(0)
-                .repeat(batch_size, 1, 1)
-            )
-        hidden_states = self.x_embedder(hidden_states)

-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]

         if self.caption_projection is not None:
             new_encoder_hidden_states = []
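To make the shape bookkeeping easier to follow, here is a standalone sketch of the square-latent rearrangement performed by the `patchify` method added earlier in this diff. The batch size, channel count, latent resolution, and patch size are illustrative assumptions; only the reshape/permute/reshape sequence mirrors the PR code.

import torch

# Standalone sketch of the square-latent path in the new `patchify`.
# Sizes are assumptions: batch 1, 16 channels, 128x128 latent, patch size 2.
batch_size, channels, height, width = 1, 16, 128, 128
patch_size = 2
patch_height, patch_width = height // patch_size, width // patch_size

hidden_states = torch.randn(batch_size, channels, height, width)
# Split the spatial dims into (patch_height, patch_size) and (patch_width, patch_size).
hidden_states = hidden_states.reshape(
    batch_size, channels, patch_height, patch_size, patch_width, patch_size
)
# Move the per-patch pixels and the channels to the end.
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
# Flatten into one token per patch.
hidden_states = hidden_states.reshape(
    batch_size, patch_height * patch_width, patch_size * patch_size * channels
)
print(hidden_states.shape)  # torch.Size([1, 4096, 64]): one 64-dim token per 2x2 patch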
@@ -789,9 +828,9 @@ def forward(
                 enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                 new_encoder_hidden_states.append(enc_hidden_state)
             encoder_hidden_states = new_encoder_hidden_states
-            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-            encoder_hidden_states.append(T5_encoder_hidden_states)
+            encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+            encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(encoder_hidden_states_t5)

         txt_ids = torch.zeros(
             batch_size,