
[LoRA] add LoRA support to HiDream and fine-tuning script #11281


Merged · 102 commits · Apr 22, 2025

Commits
b8b6465
initial commit
linoytsaban Apr 8, 2025
d728168
initial commit
linoytsaban Apr 8, 2025
0fa0993
initial commit
linoytsaban Apr 10, 2025
5ecf0ed
initial commit
linoytsaban Apr 10, 2025
911c30e
initial commit
linoytsaban Apr 10, 2025
b7fffee
initial commit
linoytsaban Apr 10, 2025
fde5eeb
Merge branch 'main' into hi-dream
linoytsaban Apr 10, 2025
02de3ce
Update examples/dreambooth/train_dreambooth_lora_hidream.py
linoytsaban Apr 14, 2025
4e08343
Merge branch 'main' into hi-dream
linoytsaban Apr 14, 2025
5257b46
move prompt embeds, pooled embeds outside
linoytsaban Apr 14, 2025
e9b4ad2
Update examples/dreambooth/train_dreambooth_lora_hidream.py
linoytsaban Apr 14, 2025
677bab1
Update examples/dreambooth/train_dreambooth_lora_hidream.py
linoytsaban Apr 14, 2025
dafc5fe
Merge branch 'main' into hi-dream
linoytsaban Apr 14, 2025
139568a
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
linoytsaban Apr 14, 2025
31aa0a2
fix import
linoytsaban Apr 14, 2025
de1654a
fix import and tokenizer 4, text encoder 4 loading
linoytsaban Apr 14, 2025
33385c9
te
linoytsaban Apr 14, 2025
d993e16
prompt embeds
linoytsaban Apr 14, 2025
c296b6f
fix naming
linoytsaban Apr 14, 2025
aa6b6e2
shapes
linoytsaban Apr 14, 2025
ecc1c18
initial commit to add HiDreamImageLoraLoaderMixin
linoytsaban Apr 14, 2025
c439c89
fix init
linoytsaban Apr 14, 2025
fc97a54
Merge branch 'huggingface:main' into hi-dream
linoytsaban Apr 14, 2025
22e9ae8
add tests
linoytsaban Apr 14, 2025
8e2e1f1
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
linoytsaban Apr 14, 2025
3653dcc
loader
linoytsaban Apr 14, 2025
32567bb
Merge branch 'main' into hi-dream
linoytsaban Apr 14, 2025
26f289d
Merge branch 'main' into hi-dream
linoytsaban Apr 14, 2025
fcf6eaa
fix model input
linoytsaban Apr 15, 2025
0fdc7dd
add code example to readme
linoytsaban Apr 15, 2025
759d204
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
linoytsaban Apr 15, 2025
62f2f15
fix default max length of text encoders
linoytsaban Apr 15, 2025
3cb1a4c
prints
linoytsaban Apr 15, 2025
82bcd44
nullify training cond in unpatchify for temp fix to incompatible shap…
linoytsaban Apr 15, 2025
e956980
Merge branch 'main' into hi-dream
linoytsaban Apr 15, 2025
764bce6
Merge branch 'main' into hi-dream
sayakpaul Apr 16, 2025
75aa8bd
smol fix
linoytsaban Apr 16, 2025
47e861f
unpatchify
linoytsaban Apr 16, 2025
a461d38
unpatchify
linoytsaban Apr 16, 2025
2a1336c
Merge branch 'main' into hi-dream
linoytsaban Apr 16, 2025
1b20d15
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
linoytsaban Apr 16, 2025
b31b595
fix validation
linoytsaban Apr 16, 2025
640bead
Merge branch 'main' into hi-dream
linoytsaban Apr 16, 2025
466c9c0
flip pred and loss
linoytsaban Apr 17, 2025
04be03a
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
linoytsaban Apr 17, 2025
baa9784
Merge branch 'main' into hi-dream
linoytsaban Apr 17, 2025
efb22e2
Merge branch 'main' into hi-dream
linoytsaban Apr 17, 2025
6043d9d
fix shift!!!
linoytsaban Apr 17, 2025
0a2adfd
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
linoytsaban Apr 17, 2025
92306d9
Merge branch 'main' into hi-dream
linoytsaban Apr 17, 2025
13e6f0d
revert unpatchify changes (for now)
linoytsaban Apr 17, 2025
c3a4047
Merge branch 'main' into hi-dream
linoytsaban Apr 18, 2025
ae39434
smol fix
linoytsaban Apr 18, 2025
c8932ed
Apply style fixes
github-actions[bot] Apr 18, 2025
c8ac7d5
workaround moe training
linoytsaban Apr 18, 2025
bf7ace6
workaround moe training
linoytsaban Apr 18, 2025
2fcb17d
remove prints
linoytsaban Apr 18, 2025
80f13be
to reduce some memory, keep vae in `weight_dtype` same as we have for…
linoytsaban Apr 18, 2025
b8039c9
refactor to align with HiDream refactor
linoytsaban Apr 18, 2025
c331597
refactor to align with HiDream refactor
linoytsaban Apr 18, 2025
c32cccc
refactor to align with HiDream refactor
linoytsaban Apr 18, 2025
6e070b8
add support for cpu offloading of text encoders
linoytsaban Apr 18, 2025
d77e42a
Apply style fixes
github-actions[bot] Apr 18, 2025
5c8c339
adjust lr and rank for train example
linoytsaban Apr 18, 2025
abfb389
fix copies
linoytsaban Apr 18, 2025
a5fe6be
Apply style fixes
github-actions[bot] Apr 18, 2025
d562776
Merge branch 'main' into hi-dream
linoytsaban Apr 19, 2025
2798d40
update README
linoytsaban Apr 19, 2025
ab960c2
update README
linoytsaban Apr 19, 2025
52d9421
update README
linoytsaban Apr 19, 2025
fc5eb48
fix license
linoytsaban Apr 19, 2025
a012914
keep prompt2,3,4 as None in validation
linoytsaban Apr 19, 2025
d5b9ecc
remove reverse ode comment
linoytsaban Apr 19, 2025
f04a13a
Update examples/dreambooth/train_dreambooth_lora_hidream.py
linoytsaban Apr 19, 2025
0b75081
Update examples/dreambooth/train_dreambooth_lora_hidream.py
linoytsaban Apr 19, 2025
4db988f
vae offload change
linoytsaban Apr 19, 2025
13192a3
fix text encoder offloading
linoytsaban Apr 19, 2025
408dfdb
Apply style fixes
github-actions[bot] Apr 19, 2025
3383446
cleaner to_kwargs
linoytsaban Apr 20, 2025
73ab201
fix module name in copied from
linoytsaban Apr 20, 2025
62019cc
Merge branch 'main' into hi-dream
linoytsaban Apr 21, 2025
a07ee59
add requirements
linoytsaban Apr 21, 2025
a751cc2
fix offloading
linoytsaban Apr 21, 2025
120b821
fix offloading
linoytsaban Apr 21, 2025
9d71d3b
fix offloading
linoytsaban Apr 21, 2025
2798d77
update transformers version in reqs
linoytsaban Apr 21, 2025
363d29b
try AutoTokenizer
linoytsaban Apr 21, 2025
8f24e8c
try AutoTokenizer
linoytsaban Apr 21, 2025
b31bdf0
Apply style fixes
github-actions[bot] Apr 21, 2025
6c77651
empty commit
linoytsaban Apr 21, 2025
246978e
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
linoytsaban Apr 21, 2025
9b6ef43
Delete tests/lora/test_lora_layers_hidream.py
linoytsaban Apr 21, 2025
fb3ac74
change tokenizer_4 to load with AutoTokenizer as well
linoytsaban Apr 21, 2025
82a4037
make text_encoder_four and tokenizer_four configurable
linoytsaban Apr 21, 2025
9de10cb
save model card
linoytsaban Apr 21, 2025
418f6a3
save model card
linoytsaban Apr 21, 2025
36c1ada
revert T5
linoytsaban Apr 21, 2025
bd399b1
fix test
linoytsaban Apr 21, 2025
a8e1a0b
Merge branch 'main' into hi-dream
linoytsaban Apr 21, 2025
bd275c0
Merge branch 'main' into hi-dream
linoytsaban Apr 21, 2025
e49be97
Merge branch 'main' into hi-dream
sayakpaul Apr 22, 2025
ed97dba
remove non diffusers lumina2 conversion
linoytsaban Apr 22, 2025
5 changes: 5 additions & 0 deletions docs/source/en/api/loaders/lora.md
@@ -28,6 +28,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.

<Tip>
@@ -91,6 +92,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin

## HiDreamImageLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin

## LoraBaseMixin

[[autodoc]] loaders.lora_base.LoraBaseMixin
133 changes: 133 additions & 0 deletions examples/dreambooth/README_hidream.md
@@ -0,0 +1,133 @@
# DreamBooth training example for HiDream Image

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.

The `train_dreambooth_lora_hidream.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream).


This will also allow us to push the trained model parameters to the Hugging Face Hub platform.

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/dreambooth` folder and run
```bash
pip install -r requirements_hidream.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook)

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, specifying torch compile mode as True can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.14.0` installed in your environment.


### Dog toy example

Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```


Now, we can launch training using:
> [!NOTE]
> The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
> the 8-bit Adam optimizer, latent caching, offloading, and no validation.
> Additionally, when only `--instance_prompt` is provided (and no `--caption_column`, which is used for per-image custom prompts),
> the text embeddings are pre-computed to save memory.

```bash
export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-hidream-lora"

accelerate launch train_dreambooth_lora_hidream.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--rank=16 \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--cache_latents \
--gradient_checkpointing \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```

To use `push_to_hub`, make sure you're logged into your Hugging Face account:

```bash
huggingface-cli login
```
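
Alternatively, if you're working in a notebook or can't use the CLI, the same login can be done programmatically. A small sketch using `huggingface_hub` (whether you pass the token explicitly or let it prompt is up to you):

```python
from huggingface_hub import login

# Prompts for a token interactively; alternatively pass token="hf_..." explicitly.
login()
```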

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, which lets us qualitatively check whether training is progressing as expected.

## Notes

Additionally, we welcome you to explore the following CLI arguments:

* `--lora_layers`: The transformer modules to apply LoRA training to. Please specify the layers as a comma-separated string, e.g. `"to_k,to_q,to_v"` will result in LoRA training of the attention layers only (see the example after this list).
* `--rank`: The rank of the LoRA layers. The higher the rank, the more parameters are trained. The default is 16.
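
As a hypothetical example, a launch that restricts LoRA to the attention projections could look like the following (all other flags mirror the command above; adjust them to your setup):

```bash
accelerate launch train_dreambooth_lora_hidream.py \
  --pretrained_model_name_or_path="HiDream-ai/HiDream-I1-Dev" \
  --instance_data_dir="dog" \
  --output_dir="trained-hidream-lora" \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --lora_layers="to_k,to_q,to_v" \
  --rank=16 \
  --learning_rate=2e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --seed="0"
```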

We provide several options for memory optimization:

* `--offload`: When enabled, we will offload the text encoders and VAE to CPU when they are not in use.
* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8-bit version of AdamW provided by the `bitsandbytes` library.
* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done.

Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream) of the `HiDreamImagePipeline` to learn more about the model.
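
After training, the saved LoRA can be loaded back into `HiDreamImagePipeline` for inference. Below is a minimal sketch assuming the adapter was saved to (or pushed as) `trained-hidream-lora`; the Llama text-encoder checkpoint and the sampler settings shown are assumptions you may need to adjust for your setup:

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

# HiDream uses a Llama checkpoint as its fourth text encoder; this repo id is an assumption (the model is gated).
llama_repo = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer_4 = AutoTokenizer.from_pretrained(llama_repo)
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    llama_repo,
    output_hidden_states=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

# Load the LoRA produced by the training run above (local folder or Hub repo id).
pipe.load_lora_weights("trained-hidream-lora")

image = pipe(
    "a photo of sks dog in a bucket",
    num_inference_steps=28,
    guidance_scale=0.0,
).images[0]
image.save("sks_dog.png")
```

Sampler settings differ between the Dev, Full, and Fast variants of HiDream, so check the pipeline documentation for the recommended values.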
8 changes: 8 additions & 0 deletions examples/dreambooth/requirements_hidream.txt
@@ -0,0 +1,8 @@
accelerate>=1.4.0
torchvision
transformers>=4.50.0
ftfy
tensorboard
Jinja2
peft>=0.14.0
sentencepiece
220 changes: 220 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_hidream.py
@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe"
    text_encoder_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    tokenizer_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py"
    transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k"

    def test_dreambooth_lora_hidream(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
                --instance_data_dir {self.instance_data_dir}
                --resolution 32
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
                --instance_data_dir {self.instance_data_dir}
                --resolution 32
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layers(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
                --instance_data_dir {self.instance_data_dir}
                --resolution 32
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lora_layers {self.transformer_layer_type}
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only params of
            # `self.transformer_layer_type` should be in the state dict.
            starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --resolution=32
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --resolution=32
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --resolution=32
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                --max_sequence_length 16
                """.split()

            resume_run_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})