
[Hi Dream] follow-up #11296


Open · wants to merge 18 commits into main

Conversation

yiyixuxu
Collaborator

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import (
    UniPCMultistepScheduler,
    HiDreamImagePipeline,
    HiDreamImageTransformer2DModel,
)


# Get more detailed memory stats
def print_detailed_memory(step_name):
    print(f"\n=== CUDA Memory Stats {step_name} ===")
    print(f"Current allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Max allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
    print(f"Current reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Max reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB")
    # Reset peaks after printing so the next call reports the max reached since this point
    torch.cuda.reset_peak_memory_stats()


device = torch.device("cuda:0")

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Initial memory usage
print_detailed_memory("start")


llama_repo = "unsloth/Meta-Llama-3.1-8B-Instruct"
# llama_repo = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
    llama_repo,
)

text_encoder_4 = LlamaForCausalLM.from_pretrained(
    llama_repo,
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)



pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    scheduler=None,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    transformer=None,
    vae=None,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()


print_detailed_memory("Before encode prompt")

with torch.no_grad():
    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
        'A cat holding a sign that says "Hi-Dreams.ai".',
        negative_prompt="bad quality, low quality",
        do_classifier_free_guidance=True,
        device=device,
        dtype=torch.bfloat16,
    )

print_detailed_memory("After encode prompt")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@sayakpaul
Member

The check_inputs() function could benefit from:

diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
index aa1de849e..ecca4548a 100644
--- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
+++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
@@ -545,12 +545,14 @@ class HiDreamImagePipeline(DiffusionPipeline):
             )
 
         if prompt_embeds is not None and negative_prompt_embeds is not None:
-            if prompt_embeds.shape != negative_prompt_embeds.shape:
-                raise ValueError(
-                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
-                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
-                    f" {negative_prompt_embeds.shape}."
-                )
+            if isinstance(prompt_embeds, list) and isinstance(negative_prompt_embeds, list):
+                for i, (p, n) in enumerate(zip(prompt_embeds, negative_prompt_embeds)):
+                    if p.shape != n.shape:
+                        raise ValueError(
+                            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                            f" got: `prompt_embeds[{i}]` {p.shape} != `negative_prompt_embeds[{i}]`"
+                            f" {n.shape}."
+                        )
 
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(

This is required when the prompt embeds are passed.

@yiyixuxu
Collaborator Author

ohh yes, I think we should have 2 prompt_embeds (t5 and llama), like @bghira suggested, instead of a list; in that case, we should accept them as inputs too
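For illustration, a rough sketch of what the per-encoder shape check could look like with two separate embeds instead of a list; the names prompt_embeds_t5 / prompt_embeds_llama3 are hypothetical, not the final API:

# Hypothetical sketch only: a per-encoder shape check, assuming the pipeline grows
# separate prompt_embeds_t5 / prompt_embeds_llama3 inputs rather than a single list.
def _check_embeds_pair(name, embeds, negative_embeds):
    if embeds is not None and negative_embeds is not None and embeds.shape != negative_embeds.shape:
        raise ValueError(
            f"`{name}` and `negative_{name}` must have the same shape when passed directly, "
            f"but got {embeds.shape} != {negative_embeds.shape}."
        )

# inside check_inputs() one could then call:
# _check_embeds_pair("prompt_embeds_t5", prompt_embeds_t5, negative_prompt_embeds_t5)
# _check_embeds_pair("prompt_embeds_llama3", prompt_embeds_llama3, negative_prompt_embeds_llama3)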

@sayakpaul
Member

Oh indeed!

@bghira
Contributor

bghira commented Apr 12, 2025

here's my MoEGate:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from diffusers.utils import is_torch_version

# NOTE: `save_load_balancing_loss` is assumed to be a helper defined in the
# surrounding training code; it is not part of diffusers.


class MoEGate(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_routed_experts=4,
        num_activated_experts=2,
        aux_loss_alpha=0.01,
    ):
        super().__init__()
        self.top_k = num_activated_experts
        self.n_routed_experts = num_routed_experts

        self.scoring_func = "softmax"
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape

        # Compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # Select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # Norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # Expert-level computation auxiliary loss with gradient checkpointing
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)

            if self.seq_aux:
                # Sequence-level auxiliary loss with gradient checkpointing
                def create_seq_aux_loss_fn(scores_view, idx, device):
                    def compute_seq_aux_loss():
                        ce = torch.zeros(bsz, self.n_routed_experts, device=device)
                        ce.scatter_add_(
                            1,
                            idx,
                            torch.ones(bsz, seq_len * aux_topk, device=device),
                        ).div_(seq_len * aux_topk / self.n_routed_experts)
                        return ce, scores_view
                    return compute_seq_aux_loss

                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)

                ce, scores_view = torch.utils.checkpoint.checkpoint(
                    create_seq_aux_loss_fn(scores_for_seq_aux, topk_idx_for_aux_loss, hidden_states.device),
                    **ckpt_kwargs
                )

                aux_loss = (ce * scores_view.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # Token-level auxiliary loss with gradient checkpointing
                def create_token_aux_loss_fn(scores_mean, idx, num_classes):
                    def compute_token_aux_loss():
                        mask_ce = F.one_hot(idx.view(-1), num_classes=num_classes)
                        ce = mask_ce.float().mean(0)
                        return ce, scores_mean
                    return compute_token_aux_loss

                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                scores_mean = scores_for_aux.mean(0)

                ce, scores_mean = torch.utils.checkpoint.checkpoint(
                    create_token_aux_loss_fn(scores_mean, topk_idx_for_aux_loss, self.n_routed_experts),
                    **ckpt_kwargs
                )

                fi = ce * self.n_routed_experts
                aux_loss = (scores_mean * fi).sum() * self.alpha

                # Store for later use but detach to prevent memory leakage
                with torch.no_grad():
                    save_load_balancing_loss((aux_loss.detach(), scores_mean.detach(), fi.detach(), self.alpha))
        else:
            aux_loss = None

        return topk_idx, topk_weight, aux_loss

if there are no problems with this approach, i think it could be adopted into Diffusers, because it greatly reduced the training memory usage (around 40G instead of 160G, even with extensive gradient checkpointing)

@bghira
Contributor

bghira commented Apr 12, 2025

there's an issue when trying to make an image with pre-encoded embeds and moving the text encoder to the meta device. it complains about tensor device misalignment. however, we can set self._execution_device = self.transformer.device manually inside __call__ and it then works 🥳

@yiyixuxu requested a review from a-r-r-o-w, April 13, 2025 05:37
@yiyixuxu
Collaborator Author

@a-r-r-o-w can you look into the comment here? if it makes sense, we can address it in a separate PR (since it is training related)
#11296 (comment)

@yiyixuxu
Collaborator Author

@bghira
I think this is expected; pre-encoded embeds work out of the box with enable_model_cpu_offload though - we put hooks on the models and they are moved to device when the forward pass is called. Since you are using pre-encoded embeds, the text_encoders are not used, so they stay on CPU the entire time, and it won't affect self._execution_device

there's an issue when trying to make image with pre-encoded embeds and moving text encoder to meta device. it complains about tensor device misalignment. however, we can set self._execution_device = self.transformer.device manually inside __call__ and it then works 🥳
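For reference, a sketch of the second half of that workflow: reuse the embeds computed in the script above with a pipeline that loads only the transformer and VAE. The exact __call__ keyword names are assumed to match what encode_prompt returns at this point in the PR, not a confirmed API.

import torch
from diffusers import HiDreamImagePipeline

# Sketch only: feed the pre-computed embeds from the script above into a second
# pipeline that has no text encoders loaded, so they never need to leave CPU/disk.
pipe_denoise = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer=None,
    tokenizer_2=None,
    tokenizer_3=None,
    tokenizer_4=None,
    text_encoder=None,
    text_encoder_2=None,
    text_encoder_3=None,
    text_encoder_4=None,
    torch_dtype=torch.bfloat16,
)
pipe_denoise.enable_model_cpu_offload()

image = pipe_denoise(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=5.0,
).images[0]
image.save("hidream_from_precomputed_embeds.png")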

@a-r-r-o-w (Member) left a comment

LGTM, just one comment.

Thanks for the suggestions @bghira!

I'll look into rewriting the MoE related code. You could probably just make the MoEGate behave the same way as it'd do during inference. This simplification should work for LoRA-related training since most low-rank training is done on the QKVO layers.

The additional gradient checkpointing does not make sense to me in the MoEGate layer though. It saves tracking gradients through a b * s * top_k tensor (some constant x num_layers) times, and that region is not really compute intensive, so that should not be a whole lot of memory compared to what happens in other layers. I have yet to test training, so I'll try to understand this better.
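To make that concrete, a minimal sketch of the simplification suggested above: an inference-style gate that drops the auxiliary-loss branch entirely (illustrative only, not the code in this PR):

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# Minimal sketch of an inference-style gate: same routing as MoEGate above,
# but without the auxiliary load-balancing loss or checkpointing.
class SimpleMoEGate(nn.Module):
    def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2):
        super().__init__()
        self.top_k = num_activated_experts
        self.n_routed_experts = num_routed_experts
        self.weight = nn.Parameter(torch.empty(num_routed_experts, embed_dim))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        # flatten (batch, seq, dim) -> (batch * seq, dim) and score each token
        logits = F.linear(hidden_states.view(-1, hidden_states.shape[-1]), self.weight, None)
        scores = logits.softmax(dim=-1)
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        # no auxiliary load-balancing loss, as at inference time
        return topk_idx, topk_weight, None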

Comment on lines 717 to 718
t5_encoder_hidden_states: torch.Tensor = None,
llama3_encoder_hidden_states: torch.Tensor = None,

Could we rename these to encoder_hidden_states_t5 and encoder_hidden_states_llama3 instead? This is so that it is consistent with our naming convention in things like HunyuanDiT/EasyAnimate

@@ -326,8 +326,10 @@ def encode_prompt(
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
negative_prompt_4: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
t5_prompt_embeds: Optional[List[torch.FloatTensor]] = None,

Same as above for these occurrences. It's just a nit in order to have some form of consistency in such cases.

@bghira
Contributor

bghira commented Apr 13, 2025

@bghira I think this is expected; pre-encoded embeds work out of the box with enable_model_cpu_offload though - we put hooks on the models and they are moved to device when the forward pass is called. Since you are using pre-encoded embeds, the text_encoders are not used, so they stay on CPU the entire time, and it won't affect self._execution_device

there's an issue when trying to make image with pre-encoded embeds and moving text encoder to meta device. it complains about tensor device misalignment. however, we can set self._execution_device = self.transformer.device manually inside __call__ and it then works 🥳

it works for Flux, SD3, PixArt, SDXL, StableDiffusion, and DeepFloyd, but somehow HiDream is the only one that ends up with meta as its execution device. that's why i brought it up.

@bghira
Contributor

bghira commented Apr 13, 2025

LGTM, just one comment.

Thanks for the suggestions @bghira!

I'll look into rewriting the MoE related code. You could probably just make the MoEGate behave the same way as it'd do during inference. This simplification should work for LoRA-related training since most low-rank training is done on the QKVO layers.

The additional gradient checkpointing does not make sense to me in the MoEGate layer though. It saves tracking gradients through a b * s * top_k tensor (some constant x num_layers) times, and that region is not really compute intensive, so that should not be a whole lot of memory compared to what happens in other layers. Yet to test training so I'll try to understand this better.

believe me, we hit >100G VRAM use without that change.

but we're not doing LoRA training for the most part, we're doing Lycoris algorithms which hit these issues easily.

@yiyixuxu requested a review from asomoza, April 14, 2025 20:43
@sayakpaul
Member

@yiyixuxu sorry for the late request. Do you think the following block could be turned into a static method like we have in FluxPipeline?

if latents.shape[-2] != latents.shape[-1]:
    B, C, H, W = latents.shape
    pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)
    img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids
    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
    if self.do_classifier_free_guidance:
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)
else:
    img_sizes = img_ids = None

This would help training quite a bit.
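For illustration, a sketch of how that block could be factored into a helper that becomes a @staticmethod on the pipeline, loosely in the spirit of FluxPipeline._prepare_latent_image_ids; the name and signature below are hypothetical:

import torch


# Hypothetical sketch: the aspect-ratio branch above factored out of __call__.
# Name, signature and placement are assumptions, not the final implementation.
def _prepare_img_sizes_and_ids(latents, patch_size, max_seq, do_classifier_free_guidance):
    if latents.shape[-2] == latents.shape[-1]:
        return None, None
    B, C, H, W = latents.shape
    pH, pW = H // patch_size, W // patch_size
    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)
    img_ids_pad = torch.zeros(max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids
    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
    if do_classifier_free_guidance:
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)
    return img_sizes, img_ids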

@ShuyUSTC

ShuyUSTC commented Apr 16, 2025

We plan to update the scheduler config in our Hugging Face model repository. After this change, all HiDream-I1 pipelines (Full/Dev/Fast) can be initialized as follows:

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # "HiDream-ai/HiDream-I1-Dev" | "HiDream-ai/HiDream-I1-Fast"
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)

There is no longer a need to explicitly initialize a scheduler outside of HiDreamImagePipeline.

Moreover, we updated the scheduler of the Dev and Fast models to FlowMatchLCMScheduler.

For more details, please refer to the pull requests:

https://huggingface.co/HiDream-ai/HiDream-I1-Full/discussions/19#67ff4f0cdf848a9f25c82e95
https://huggingface.co/HiDream-ai/HiDream-I1-Dev/discussions/5#67ff46fcad4d4fb4cd2bd8cc
https://huggingface.co/HiDream-ai/HiDream-I1-Fast/discussions/2#67ff511cd0d0b6a992592010

Would appreciate any feedback on this update.

@yiyixuxu
Collaborator Author

@ShuyUSTC oh thanks, we can update our doc examples once you've updated the repo!

@yiyixuxu
Collaborator Author

@sayakpaul i moved it into the transformer; not sure why one code path (img ids for non-square images) was inside the pipeline, while the other code path (img ids for square images) was prepared inside the transformer

@sayakpaul
Member

Thanks a lot!

@ShuyUSTC

@ShuyUSTC oh thanks, we can update our doc examples once you updated the repo!

Thanks very much!

@ShuyUSTC

We have updated the Hugging Face model repository. Now one can run inference with the following code:

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import HiDreamImagePipeline
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # "HiDream-ai/HiDream-I1-Dev" | "HiDream-ai/HiDream-I1-Fast"
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)

pipe = pipe.to('cuda')

image = pipe(
    'A cat holding a sign that says "HiDream.ai".',
    height=1024,
    width=1024,
    guidance_scale=5.0,  # 0.0 for Dev&Fast
    num_inference_steps=50,  # 28 for Dev and 16 for Fast
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("output.png")

@vladmandic
Contributor

just imo, updating the default scheduler to something that doesn't exist upstream in released diffusers is not the way to go, as anyone who had a working hidream yesterday will have an immediate failure today without any clear explanation that they need to update to the latest diffusers main codebase.

@bghira
Contributor

bghira commented Apr 16, 2025

ah that explains why all of my users are like "why do you keep breaking HiDream", and I have to let them know, "that is not me".

@ShuyUSTC

@vladmandic @bghira Thanks for the feedback! I'd like to clarify a few points regarding the scheduler update:

  1. This change does not alter the inference behavior of any HiDream-I1 models. The update is purely aimed at improving the pipeline's initialization experience — users can still manually initialize the scheduler externally, just like before. This update simply makes the default usage more streamlined for new users.

  2. The default schedulers for HiDream-I1-Fast and HiDream-I1-Dev have always been FlowMatchLCMScheduler. The reason we hadn't yet provided examples for them in the diffusers docs was that this scheduler wasn't available in the upstream diffusers library until recently.

  3. The FlashFlowMatchEulerDiscreteScheduler used in our GitHub repository is just a local implementation of FlowMatchLCMScheduler. Now that it's officially available in diffusers, we're aligning everything accordingly.

We understand the concern around breaking changes — but in this case, there’s no functional regression. Existing workflows that use externally defined schedulers will continue to work as expected.

Hopefully this clears up the confusion — and saves us all a few user support messages 😅
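For reference, a sketch of the externally-defined-scheduler path described above; the specific scheduler and settings are assumptions taken from the older doc-style example for the Full model, not an official recommendation:

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import HiDreamImagePipeline, UniPCMultistepScheduler

# Sketch of keeping the previous behavior: construct the scheduler yourself and pass
# it to from_pretrained, which overrides whatever the repo's scheduler config says.
scheduler = UniPCMultistepScheduler(
    flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
)

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    scheduler=scheduler,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)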

@bghira
Contributor

bghira commented Apr 16, 2025

no, it literally complains when starting up that the scheduler is not available in Diffusers, because the model config was updated in the repository, which is where the path to the scheduler is read from; the model config DID work before. now it doesn't. we didn't explicitly need an external scheduler defined before. those people receive an error.

out[:, :, 0 : pH * pW] = hidden_states
hidden_states = out
# Patchify the input
hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
YehLi commented Apr 16, 2025

    # Patchify the input
    if img_sizes is not None and img_ids is not None:
        B, C, S, _ = hidden_states.shape
        hidden_states_masks = torch.zeros((B, self.max_seq), dtype=hidden_states.dtype, device=hidden_states.device)
        for i, img_size in enumerate(img_sizes):
            hidden_states_masks[i, 0:img_size[0] * img_size[1]] = 1
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(B, S, self.config.patch_size * self.config.patch_size * C)
    else:
        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)

Keep img_sizes and img_ids, since samples with different aspect ratios would be contained in the same batch during training.

@yiyixuxu
Collaborator Author

thanks! I made it so that the patchify step can be skipped, i.e. for samples with different aspect ratios, the user would have to prepare the patchified hidden_states, hidden_states_mask, img_sizes and img_ids outside of the model and pass them as inputs. I think it is easier this way:

  1. this library is mainly for inference and fine-tuning; I think training with different aspect ratios is mostly for pre-training, no?
  2. they would need to prepare latents differently, and we currently do not support that anyway


  1. During pre-training, the input images are not cropped by default. Thus, it is recommended that users do not apply any cropping to images and instead prepare the patchified hidden_states, img_sizes and img_ids outside of the model during fine-tuning. It is fine to fine-tune the model on multi-aspect bucketing like SDXL. The main purpose of img_sizes and img_ids is to give users more options if they do not want to crop the images.
  2. They can prepare latents in the dataloader with shape (B, C, H*W / patch_size / patch_size, patch_size*patch_size) and then pad them to the same length (B, C, S, patch_size*patch_size), as in the sketch below.
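A small sketch of that padding step; pad_patchified_latents and max_seq are illustrative names, not part of the library:

import torch


# Sketch of padding variously sized patchified latents to a fixed sequence length,
# following the shapes described above; max_seq stands in for the transformer's
# maximum patch-sequence length.
def pad_patchified_latents(patchified, max_seq):
    # patchified: (B, C, H*W / patch_size**2, patch_size*patch_size) for one aspect ratio
    B, C, S, P = patchified.shape
    padded = patchified.new_zeros((B, C, max_seq, P))
    padded[:, :, :S] = patchified
    mask = torch.zeros(B, max_seq, dtype=patchified.dtype, device=patchified.device)
    mask[:, :S] = 1
    return padded, mask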

@yiyixuxu
Collaborator Author

@vladmandic @bghira

sorry we broke things for some of your users. I thought it'd be ok since it's only on main, and in our doc example we initialize the scheduler outside of the pipeline, so it seemed like a relatively ok time to switch if we ever want to switch the default.

Do you have suggestions for a better way?

@ShuyUSTC

@bghira which model do you use? For HiDream-I1-Dev and HiDream-I1-Fast, the original scheduler in the Hugging Face model repository, i.e. FlowMatchEulerDiscreteScheduler, is not the official scheduler; our original inference code needed to create the scheduler manually. It is recommended to use FlowMatchLCMScheduler, recently merged into the Diffusers main branch, which is now created automatically within HiDreamImagePipeline.

@bghira
Contributor

bghira commented Apr 16, 2025

no, it seemed to work as expected with Euler so i figured it was ok. i will update everything on my end again and hope it does not break; are there more changes expected?

@yiyixuxu
Collaborator Author

@bghira
i added a short deprecation cycle for all the changes in this PR (until the next release).
I can let you test before merge though.

@bghira
Contributor

bghira commented Apr 16, 2025

finetuning on multi-aspect bucketing is fine and common; but i'm confused about what's meant by mixing shapes. making use of nested_tensor to purposely mix aspect-bucketed data in a single batch?

@yiyixuxu
Collaborator Author

@ShuyUSTC @YehLi
I found out that without CFG it's not able to generate square images, but it is able to do so for non-square images (though not a good job). is this expected?

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import (
    HiDreamImagePipeline,
    HiDreamImageTransformer2DModel,
)

device = "cuda:3"
testing_branch = "main"

llama_repo = "unsloth/Meta-Llama-3.1-8B-Instruct"
# llama_repo = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
    llama_repo,
)

text_encoder_4 = LlamaForCausalLM.from_pretrained(
    llama_repo,
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16
)


pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload(device=device)



test_guidance = [0.0]
test_negative_prompt = [None]
test_width = [1024, 768]
num_inference_steps = 50

for guidance in test_guidance:
    for negative_prompt in test_negative_prompt:
        for width in test_width:
            print(f" --------------------------------------------------------------")
            print(f"testing guidance: {guidance}, negative_prompt: {negative_prompt}, width: {width}")
            image = pipe(
                'A cat holding a sign that says "Hi-Dreams.ai".',
                negative_prompt=negative_prompt,
                height=1024,
                width=width,
                guidance_scale=guidance,
                num_inference_steps=num_inference_steps,
                generator=torch.Generator("cuda").manual_seed(0),
            ).images[0]

            image.save(f"yiyi_test_2_out_{testing_branch}_{guidance}_{negative_prompt}_{width}.png")
            print(f"saved image to yiyi_test_2_out_{testing_branch}_{guidance}_{negative_prompt}_{width}.png")
            print(f" --------------------------------------------------------------")

1024 x 1024: yiyi_test_2_out_main_0.0_None_1024.png (attached image)

1024 x 768: yiyi_test_2_out_main_0.0_None_768.png (attached image)

@ShuyUSTC

@yiyixuxu We recommend enabling CFG for the HiDream-I1-Full model. The quality cannot be guaranteed without using CFG.
