diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 04622a7e04b2..30f6c3a34c9d 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -8,7 +8,7 @@
 from ...loaders import PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,108 @@ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_train
         x = torch.cat(x_arr, dim=0)
         return x

-    def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.config.patch_size * self.config.patch_size
-        if isinstance(x, torch.Tensor):
-            B, C = x.shape[0], x.shape[1]
-            device = x.device
-            dtype = x.dtype
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
         else:
-            B, C = len(x), x[0].shape[0]
-            device = x[0].device
-            dtype = x[0].dtype
-        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
+            hidden_states_masks = None
+
+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )

-        if img_sizes is not None:
-            for i, img_size in enumerate(img_sizes):
-                x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            B, C, S, _ = x.shape
-            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
-        elif isinstance(x, torch.Tensor):
-            B, C, Hp1, Wp2 = x.shape
-            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
-            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
-            x = x.permute(0, 2, 4, 3, 5, 1)
-            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
-            img_sizes = [[pH, pW]] * B
-            x_masks = None
         else:
-            raise NotImplementedError
-        return x, x_masks, img_sizes
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )
+
+        return hidden_states, hidden_states_masks, img_sizes, img_ids

     def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
-        img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        **kwargs,
     ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.34.0", deprecation_message)
+
+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -745,42 +807,19 @@ def forward(
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype

-        if hidden_states.shape[-2] != hidden_states.shape[-1]:
-            B, C, H, W = hidden_states.shape
-            patch_size = self.config.patch_size
-            pH, pW = H // patch_size, W // patch_size
-            out = torch.zeros(
-                (B, C, self.max_seq, patch_size * patch_size),
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-            )
-            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
-            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
-            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
-            out[:, :, 0 : pH * pW] = hidden_states
-            hidden_states = out
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)

         # 0. time
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder

-        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if hidden_states_masks is None:
-            pH, pW = img_sizes[0]
-            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = (
-                img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
-                .unsqueeze(0)
-                .repeat(batch_size, 1, 1)
-            )
-        hidden_states = self.x_embedder(hidden_states)
-
-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]

         if self.caption_projection is not None:
             new_encoder_hidden_states = []
@@ -789,9 +828,9 @@ def forward(
                 enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                 new_encoder_hidden_states.append(enc_hidden_state)
             encoder_hidden_states = new_encoder_hidden_states
-            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-            encoder_hidden_states.append(T5_encoder_hidden_states)
+            encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+            encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(encoder_hidden_states_t5)

         txt_ids = torch.zeros(
             batch_size,
diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
index
b17329a19959..1a3315892272 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -15,7 +15,7 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, HiDreamImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import HiDreamImagePipelineOutput @@ -38,9 +38,6 @@ >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline - >>> scheduler = UniPCMultistepScheduler( - ... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True - ... ) >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( @@ -52,7 +49,6 @@ >>> pipe = HiDreamImagePipeline.from_pretrained( ... "HiDream-ai/HiDream-I1-Full", - ... scheduler=scheduler, ... tokenizer_4=tokenizer_4, ... text_encoder_4=text_encoder_4, ... torch_dtype=torch.bfloat16, @@ -148,7 +144,7 @@ def retrieve_timesteps( class HiDreamImagePipeline(DiffusionPipeline): model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"] def __init__( self, @@ -309,10 +305,10 @@ def _get_llama3_prompt_embeds( def encode_prompt( self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - prompt_3: Union[str, List[str]], - prompt_4: Union[str, List[str]], + prompt: Optional[Union[str, List[str]]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + prompt_4: Optional[Union[str, List[str]]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, num_images_per_prompt: int = 1, @@ -321,8 +317,10 @@ def encode_prompt( negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, negative_prompt_4: Optional[Union[str, List[str]]] = None, - prompt_embeds: Optional[List[torch.FloatTensor]] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None, + prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 128, @@ -332,120 +330,177 @@ def encode_prompt( if prompt is not None: batch_size = len(prompt) else: - batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] - - prompt_embeds, pooled_prompt_embeds = self._encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_3=prompt_3, - prompt_4=prompt_4, - device=device, - dtype=dtype, - num_images_per_prompt=num_images_per_prompt, - prompt_embeds=prompt_embeds, - 
pooled_prompt_embeds=pooled_prompt_embeds, - max_sequence_length=max_sequence_length, - ) + batch_size = pooled_prompt_embeds.shape[0] - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - negative_prompt_3 = negative_prompt_3 or negative_prompt - negative_prompt_4 = negative_prompt_4 or negative_prompt + device = device or self._execution_device - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - negative_prompt_3 = ( - batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 - ) - negative_prompt_4 = ( - batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + if pooled_prompt_embeds is None: + pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype ) - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if len(negative_prompt) > 1 and len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must be of length 1 or {batch_size}") - negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt( - prompt=negative_prompt, - prompt_2=negative_prompt_2, - prompt_3=negative_prompt_3, - prompt_4=negative_prompt_4, - device=device, - dtype=dtype, - num_images_per_prompt=num_images_per_prompt, - prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=negative_pooled_prompt_embeds, - max_sequence_length=max_sequence_length, + negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype ) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - def _encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - prompt_3: Union[str, List[str]], - prompt_4: Union[str, List[str]], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[List[torch.FloatTensor]] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 128, - ): - device = device or self._execution_device - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] + if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1) if pooled_prompt_embeds is None: prompt_2 = prompt_2 or prompt 
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( - self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype - ) + if len(prompt_2) > 1 and len(prompt_2) != batch_size: + raise ValueError(f"prompt_2 must be of length 1 or {batch_size}") + pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype ) + + if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + + if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size: + raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype + ) + + if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) - pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1 + ) - if prompt_embeds is None: + if prompt_embeds_t5 is None: prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + if len(prompt_3) > 1 and len(prompt_3) != batch_size: + raise ValueError(f"prompt_3 must be of length 1 or {batch_size}") + + prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype) + + if prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_t5 is None: + negative_prompt_3 = negative_prompt_3 or negative_prompt + negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + + if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size: + raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}") + + negative_prompt_embeds_t5 = self._get_t5_prompt_embeds( + negative_prompt_3, max_sequence_length, device, dtype + ) + + if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + + if prompt_embeds_llama3 is None: prompt_4 = prompt_4 or prompt prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 - t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype) - llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype) + if len(prompt_4) > 1 and len(prompt_4) != batch_size: + raise ValueError(f"prompt_4 must be of length 1 or {batch_size}") + + prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype) + + if 
prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) - _, seq_len, _ = t5_prompt_embeds.shape - t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1) - t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None: + negative_prompt_4 = negative_prompt_4 or negative_prompt + negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + + if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size: + raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}") - _, _, seq_len, dim = llama3_prompt_embeds.shape - llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) - llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds( + negative_prompt_4, max_sequence_length, device, dtype + ) - prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] + if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + # duplicate pooled_prompt_embeds for each generation per prompt + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}") + prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}") + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + + if do_classifier_free_guidance: + # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len = negative_pooled_prompt_embeds.shape + if bs_embed == 1 and batch_size > 1: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}") + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 
1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}") + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}") + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view( + -1, batch_size * num_images_per_prompt, seq_len, dim + ) - return prompt_embeds, pooled_prompt_embeds + return ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) def enable_vae_slicing(self): r""" @@ -476,6 +531,115 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + negative_prompt_4=None, + prompt_embeds_t5=None, + prompt_embeds_llama3=None, + negative_prompt_embeds_t5=None, + negative_prompt_embeds_llama3=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to" + " only forward one of the two." + ) + elif prompt_4 is not None and prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and pooled_prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_t5 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined." 
+ ) + elif prompt is None and prompt_embeds_llama3 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)): + raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}") + + if negative_prompt is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:" + f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two." + ) + elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:" + f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two." + ) + + if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None: + if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape: + raise ValueError( + "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but" + f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`" + f" {negative_pooled_prompt_embeds.shape}." + ) + if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None: + if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape: + raise ValueError( + "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but" + f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`" + f" {negative_prompt_embeds_t5.shape}." + ) + if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None: + if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape: + raise ValueError( + "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but" + f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`" + f" {negative_prompt_embeds_llama3.shape}." 
+ ) + def prepare_latents( self, batch_size, @@ -542,8 +706,10 @@ def __call__( num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_t5: Optional[torch.FloatTensor] = None, + prompt_embeds_llama3: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_llama3: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -552,6 +718,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 128, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -649,6 +816,22 @@ def __call__( [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated. images. """ + + prompt_embeds = kwargs.get("prompt_embeds", None) + negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None) + + if prompt_embeds is not None: + deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead." + deprecate("prompt_embeds", "0.34.0", deprecation_message) + prompt_embeds_t5 = prompt_embeds[0] + prompt_embeds_llama3 = prompt_embeds[1] + + if negative_prompt_embeds is not None: + deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead." + deprecate("negative_prompt_embeds", "0.34.0", deprecation_message) + negative_prompt_embeds_t5 = negative_prompt_embeds[0] + negative_prompt_embeds_llama3 = negative_prompt_embeds[1] + height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -658,6 +841,25 @@ def __call__( scale = math.sqrt(scale) width, height = int(width * scale // division * division), int(height * scale // division * division) + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs self._interrupt = False @@ -667,17 +869,18 @@ def __call__( batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - elif prompt_embeds is not None: - batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] - else: - batch_size = 1 + elif pooled_prompt_embeds is not None: + batch_size = pooled_prompt_embeds.shape[0] device = self._execution_device + # 3. Encode prompt lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None ( - prompt_embeds, - negative_prompt_embeds, + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( @@ -690,8 +893,10 @@ def __call__( negative_prompt_3=negative_prompt_3, negative_prompt_4=negative_prompt_4, do_classifier_free_guidance=self.do_classifier_free_guidance, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, device=device, @@ -701,13 +906,8 @@ def __call__( ) if self.do_classifier_free_guidance: - prompt_embeds_arr = [] - for n, p in zip(negative_prompt_embeds, prompt_embeds): - if len(n.shape) == 3: - prompt_embeds_arr.append(torch.cat([n, p], dim=0)) - else: - prompt_embeds_arr.append(torch.cat([n, p], dim=1)) - prompt_embeds = prompt_embeds_arr + prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0) + prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 4. 
Prepare latent variables @@ -723,26 +923,6 @@ def __call__( latents, ) - if latents.shape[-2] != latents.shape[-1]: - B, C, H, W = latents.shape - pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size - - img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) - img_ids = torch.zeros(pH, pW, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] - img_ids = img_ids.reshape(pH * pW, -1) - img_ids_pad = torch.zeros(self.transformer.max_seq, 3) - img_ids_pad[: pH * pW, :] = img_ids - - img_sizes = img_sizes.unsqueeze(0).to(latents.device) - img_ids = img_ids_pad.unsqueeze(0).to(latents.device) - if self.do_classifier_free_guidance: - img_sizes = img_sizes.repeat(2 * B, 1) - img_ids = img_ids.repeat(2 * B, 1, 1) - else: - img_sizes = img_ids = None - # 5. Prepare timesteps mu = calculate_shift(self.transformer.max_seq) scheduler_kwargs = {"mu": mu} @@ -774,10 +954,9 @@ def __call__( noise_pred = self.transformer( hidden_states=latent_model_input, timesteps=timestep, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states_t5=prompt_embeds_t5, + encoder_hidden_states_llama3=prompt_embeds_llama3, pooled_embeds=pooled_prompt_embeds, - img_sizes=img_sizes, - img_ids=img_ids, return_dict=False, )[0] noise_pred = -noise_pred @@ -803,8 +982,9 @@ def __call__( callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5) + prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3) + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/tests/pipelines/hidream/test_pipeline_hidream.py b/tests/pipelines/hidream/test_pipeline_hidream.py index 597a20216882..525e29eaa6a2 100644 --- a/tests/pipelines/hidream/test_pipeline_hidream.py +++ b/tests/pipelines/hidream/test_pipeline_hidream.py @@ -43,7 +43,7 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = HiDreamImagePipeline - params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
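
A minimal standalone sketch (not part of the patch) of what the reworked `patchify` computes on the square-latent path: the latent of shape (batch_size, channels, height, width) is cut into non-overlapping patch_size x patch_size patches and flattened to (batch_size, num_patches, patch_size * patch_size * channels), while per-token (row, column) position ids are built alongside it. The shapes below are toy values, not the real model configuration.

import torch

# Toy shapes; the real values come from the transformer config and the VAE latent.
batch_size, channels, height, width, patch_size = 1, 16, 32, 32, 2
patch_height, patch_width = height // patch_size, width // patch_size

# (B, C, H, W) -> (B, patch_height * patch_width, patch_size * patch_size * C)
hidden_states = torch.randn(batch_size, channels, height, width)
hidden_states = hidden_states.reshape(batch_size, channels, patch_height, patch_size, patch_width, patch_size)
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
hidden_states = hidden_states.reshape(batch_size, patch_height * patch_width, patch_size * patch_size * channels)

# Per-token (row, col) ids, mirroring the img_ids built inside patchify.
img_ids = torch.zeros(patch_height, patch_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(patch_height)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(patch_width)[None, :]
img_ids = img_ids.reshape(patch_height * patch_width, -1).unsqueeze(0).repeat(batch_size, 1, 1)

print(hidden_states.shape)  # torch.Size([1, 256, 64])
print(img_ids.shape)        # torch.Size([1, 256, 3])

For non-square latents the same reshape runs in the other branch, the token sequence is padded to `max_seq`, and `hidden_states_masks` marks which positions hold real patches.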
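
With the pipeline changes, the T5 and Llama3 prompt embeddings travel under separate argument names end to end. A usage sketch under stated assumptions: the model ids follow the docstring example in the diff, the `LlamaForCausalLM.from_pretrained` keyword arguments are assumptions (the docstring truncates them), and a CUDA device with enough memory is assumed.

import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

from diffusers import HiDreamImagePipeline

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,  # assumption: the transformer consumes per-layer Llama3 hidden states
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")

# encode_prompt() now returns the T5 and Llama3 embeddings (and their negatives) separately.
(
    prompt_embeds_t5,
    negative_prompt_embeds_t5,
    prompt_embeds_llama3,
    negative_prompt_embeds_llama3,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt="a cat wearing a space suit", do_classifier_free_guidance=True)

# __call__() accepts the precomputed embeddings under the new names.
image = pipe(
    prompt_embeds_t5=prompt_embeds_t5,
    prompt_embeds_llama3=prompt_embeds_llama3,
    negative_prompt_embeds_t5=negative_prompt_embeds_t5,
    negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]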
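
The old combined arguments keep working for now: both the transformer (`encoder_hidden_states`) and the pipeline (`prompt_embeds` / `negative_prompt_embeds`) unpack the pair internally and emit a deprecation warning pointing at the new names, with removal scheduled for 0.34.0. Continuing the sketch above (reusing `pipe` and the embeddings already computed), an old-style call looks like this:

# Deprecated path: the pipeline splits the lists itself,
# prompt_embeds[0] -> prompt_embeds_t5 and prompt_embeds[1] -> prompt_embeds_llama3.
image = pipe(
    prompt_embeds=[prompt_embeds_t5, prompt_embeds_llama3],
    negative_prompt_embeds=[negative_prompt_embeds_t5, negative_prompt_embeds_llama3],
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]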