[LoRA] add LoRA support to HiDream and fine-tuning script #11281

Merged
merged 102 commits into huggingface:main from hi-dream on Apr 22, 2025

Conversation

@linoytsaban linoytsaban (Collaborator) commented Apr 10, 2025

Add LoRA training support for HiDream Image.

  • Trains transformer layers (text encoder training is not supported at the moment).
  • MoE training is disabled for LoRA fine-tuning (potentially revisited in the future).

Memory optimizations:

  1. CPU offloading support for the VAE and text encoders
  2. Latent caching (a rough sketch follows this list)
  3. VAE loaded in mixed precision
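
To illustrate item 2, here is a minimal latent-caching sketch, assuming a diffusers `AutoencoderKL`-style VAE interface; this is not the exact code from the training script:

```python
import torch


@torch.no_grad()
def cache_latents(vae, dataloader, device, dtype=torch.bfloat16):
    """Encode every training image once so the VAE can be moved off the GPU afterwards."""
    cached = []
    vae.to(device, dtype=dtype)
    for batch in dataloader:
        pixel_values = batch["pixel_values"].to(device, dtype=dtype)
        latents = vae.encode(pixel_values).latent_dist.sample()
        # The exact scaling/shift depends on the VAE config; scaling_factor is the common case.
        latents = latents * vae.config.scaling_factor
        cached.append(latents.cpu())
    vae.to("cpu")  # free GPU memory for the transformer before training starts
    return cached
```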

Example training command arguments:

```bash
--pretrained_model_name_or_path=HiDream-ai/HiDream-I1-Full \
--dataset_name=Norod78/Yarn-art-style \
--output_dir=trained-hidream-lora \
--mixed_precision=bf16 \
--instance_prompt="a dog, yarn art style" \
--caption_column=text \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--optimizer=prodigy \
--rank=16 \
--learning_rate=1.0 \
--report_to=wandb \
--lr_scheduler=constant \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--validation_epochs=25 \
--validation_prompt="yoda, yarn art style" \
--seed=0
```
[Screenshot attached: 2025-04-21 10:42]
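
To use the resulting LoRA at inference time, a rough sketch (assuming `HiDreamImagePipeline` plus the LoRA loading support added in this PR; depending on the checkpoint layout, the Llama-3.1 text encoder/tokenizer may need to be loaded and passed in separately via `text_encoder_4`/`tokenizer_4`):

```python
import torch
from diffusers import HiDreamImagePipeline

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("trained-hidream-lora")  # --output_dir from the command above
pipe.to("cuda")

image = pipe("yoda, yarn art style").images[0]
image.save("yoda_yarn_art.png")
```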

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 1434 to 1441
```python
def compute_text_embeddings(prompt, text_encoders, tokenizers):
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            text_encoders, tokenizers, prompt, args.max_sequence_length
        )
        prompt_embeds = prompt_embeds.to(accelerator.device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
    return prompt_embeds, pooled_prompt_embeds
```
Member commented:

Hundred percent not a blocker for this PR. But we could allow users to drop out the other text encoders by zeroing those embeds.
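
For illustration, a hedged sketch of what such a drop-out could look like; the helper and probabilities below are hypothetical, not part of the script:

```python
import torch


def maybe_zero_text_embeds(t5_embeds, llama_embeds, drop_t5_prob=0.1, drop_llama_prob=0.1):
    """Hypothetical helper: randomly zero one encoder's embeddings during training so the
    model learns not to depend on it (and users could later drop that encoder entirely)."""
    if torch.rand(()) < drop_t5_prob:
        t5_embeds = torch.zeros_like(t5_embeds)
    if torch.rand(()) < drop_llama_prob:
        llama_embeds = torch.zeros_like(llama_embeds)
    return t5_embeds, llama_embeds
```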

Member commented:

Also, in `encode_prompt()`:

```python
prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
```

Would `prompt_embeds = prompt_embeds.to(accelerator.device)` work here?
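
Presumably not as-is, since a plain Python list has no `.to()`; each tensor would need to be moved individually. A tiny self-contained sketch (the shapes below are illustrative stand-ins, not the real embedding shapes):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative stand-ins for the tensors returned by encode_prompt().
t5_prompt_embeds = torch.randn(1, 128, 4096)
llama3_prompt_embeds = torch.randn(32, 1, 128, 4096)

prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]

# A list has no .to(), so move each tensor element-wise:
prompt_embeds = [e.to(device) for e in prompt_embeds]
```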

@sayakpaul sayakpaul (Member) left a comment:

Thanks a lot for this!

I have left some clarification questions. LMK if they make sense.

@a-r-r-o-w a-r-r-o-w mentioned this pull request Apr 10, 2025
@hlky hlky (Contributor) left a comment:

Just a few comments, will review again when out of draft, thanks @linoytsaban!

@linoytsaban (Collaborator, Author) commented:

Hi @Nerogar!
For context: the `self.training` / `if training` conditions in the `HiDreamImageTransformer2DModel` implementation come from the original implementation and refer to the case where the aux loss is used and the experts are trained as well.
For LoRA training in this PR we currently do not support the aux loss or MoE training, hence we added `force_inference_output`:

```python
if self.training and self.alpha > 0.0 and not self._force_inference_output:
```

Specifically, we ran into an issue with `unpatchify` when initially working on this PR: the output did not have the shape we expected, which caused an error when computing the loss due to a shape mismatch between `target` and `model_pred`.
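
To make the guard concrete, here is a small self-contained toy example (not the HiDream implementation; the module and names are illustrative): with the force-inference flag set, the training-mode forward skips the aux-loss branch and returns a plain tensor that can be compared directly against the target.

```python
import torch
import torch.nn as nn


class ToyMoEBlock(nn.Module):
    """Toy stand-in for an MoE block with an optional aux-loss branch."""

    def __init__(self, dim=8, alpha=0.01, force_inference_output=False):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.alpha = alpha
        self._force_inference_output = force_inference_output

    def forward(self, x):
        out = self.proj(x)
        if self.training and self.alpha > 0.0 and not self._force_inference_output:
            aux_loss = out.float().pow(2).mean()  # stand-in for a load-balancing loss
            return out, aux_loss  # extra return value: downstream loss code must handle a tuple
        return out  # LoRA-training path: a plain tensor with the same shape as the target


block = ToyMoEBlock(force_inference_output=True).train()
model_pred = block(torch.randn(2, 8))
print(model_pred.shape)  # torch.Size([2, 8]), directly comparable to the target
```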

@linoytsaban (Collaborator, Author):

@bot /style

Style fixes have been applied.

@linoytsaban linoytsaban requested a review from sayakpaul April 21, 2025 18:21
@sayakpaul sayakpaul (Member) left a comment:

Thank you!

@linoytsaban linoytsaban merged commit e30d3bf into huggingface:main Apr 22, 2025
29 checks passed
@linoytsaban linoytsaban deleted the hi-dream branch April 22, 2025 08:44
@vladmandic (Contributor) commented:

follow-up issue: #11383
