
Commit 0567932

[Hi Dream] follow-up (#11296)
* add
1 parent 29d2afb commit 0567932

File tree: 3 files changed (+421 additions, -202 deletions)


src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 100 additions & 61 deletions
@@ -8,7 +8,7 @@
 from ...loaders import PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
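
The only import change pulls in diffusers' `deprecate` helper, which backs the argument deprecations added to `forward` further down. A minimal sketch of the call pattern this commit uses; the warning mechanics described in the comments are an assumption about the helper rather than something shown in this diff:

    from diffusers.utils import deprecate

    # Flag a keyword argument as deprecated and scheduled for removal.
    # (Assumed behavior: emits a FutureWarning while the installed diffusers
    # version is still below the removal version given here.)
    deprecate(
        "encoder_hidden_states",  # name of the deprecated argument
        "0.34.0",                 # version in which it is slated for removal
        "Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead.",
    )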
@@ -686,46 +686,108 @@ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_train
         x = torch.cat(x_arr, dim=0)
         return x

-    def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.config.patch_size * self.config.patch_size
-        if isinstance(x, torch.Tensor):
-            B, C = x.shape[0], x.shape[1]
-            device = x.device
-            dtype = x.dtype
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
         else:
-            B, C = len(x), x[0].shape[0]
-            device = x[0].device
-            dtype = x[0].dtype
-            x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
+            hidden_states_masks = None
+
+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )

-        if img_sizes is not None:
-            for i, img_size in enumerate(img_sizes):
-                x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            B, C, S, _ = x.shape
-            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
-        elif isinstance(x, torch.Tensor):
-            B, C, Hp1, Wp2 = x.shape
-            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
-            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
-            x = x.permute(0, 2, 4, 3, 5, 1)
-            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
-            img_sizes = [[pH, pW]] * B
-            x_masks = None
         else:
-            raise NotImplementedError
-        return x, x_masks, img_sizes
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )
+
+        return hidden_states, hidden_states_masks, img_sizes, img_ids

     def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
-        img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        **kwargs,
     ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.34.0", deprecation_message)
+
+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -745,42 +807,19 @@ def forward(
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype

-        if hidden_states.shape[-2] != hidden_states.shape[-1]:
-            B, C, H, W = hidden_states.shape
-            patch_size = self.config.patch_size
-            pH, pW = H // patch_size, W // patch_size
-            out = torch.zeros(
-                (B, C, self.max_seq, patch_size * patch_size),
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-            )
-            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
-            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
-            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
-            out[:, :, 0 : pH * pW] = hidden_states
-            hidden_states = out
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)

         # 0. time
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder

-        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if hidden_states_masks is None:
-            pH, pW = img_sizes[0]
-            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = (
-                img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
-                .unsqueeze(0)
-                .repeat(batch_size, 1, 1)
-            )
-        hidden_states = self.x_embedder(hidden_states)
-
-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]

         if self.caption_projection is not None:
             new_encoder_hidden_states = []
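
For callers, the practical effect of this hunk is that `forward` patchifies the raw (B, C, H, W) latents internally and takes the two text streams as separate keyword arguments. A hypothetical call sketch; `model` stands for an already-instantiated transformer, and the tensor names are placeholders rather than values from this commit:

    # New calling convention: pass T5 and Llama3 embeddings separately.
    output = model(
        hidden_states=latents,                       # (B, C, H, W), unpatchified
        timesteps=timesteps,
        encoder_hidden_states_t5=t5_embeds,          # single T5 sequence
        encoder_hidden_states_llama3=llama3_embeds,  # indexable per layer via config.llama_layers
        pooled_embeds=pooled_embeds,
    )

    # Deprecated calling convention: still accepted via **kwargs, but routed
    # through `deprecate` and unpacked as (t5, llama3); slated for removal in 0.34.0.
    output = model(
        hidden_states=latents,
        timesteps=timesteps,
        encoder_hidden_states=(t5_embeds, llama3_embeds),
        pooled_embeds=pooled_embeds,
    )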
@@ -789,9 +828,9 @@ def forward(
                 enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                 new_encoder_hidden_states.append(enc_hidden_state)
             encoder_hidden_states = new_encoder_hidden_states
-            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-            encoder_hidden_states.append(T5_encoder_hidden_states)
+            encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+            encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(encoder_hidden_states_t5)

         txt_ids = torch.zeros(
             batch_size,
