diff --git a/docs/diffusers/imgs/README.md b/docs/diffusers/imgs/README.md
new file mode 100644
index 0000000000..69272da5fd
--- /dev/null
+++ b/docs/diffusers/imgs/README.md
@@ -0,0 +1,3 @@
+### Image Credits
+
+The images in this folder are taken from the [Hugging Face Diffusers repository](https://github.com/huggingface/diffusers/tree/main/docs/source/en/imgs) and are subject to the Apache 2.0 license of the Diffusers project.
diff --git a/docs/diffusers/imgs/access_request.png b/docs/diffusers/imgs/access_request.png
new file mode 100644
index 0000000000..33c6abc88d
Binary files /dev/null and b/docs/diffusers/imgs/access_request.png differ
diff --git a/docs/diffusers/imgs/diffusers_library.jpg b/docs/diffusers/imgs/diffusers_library.jpg
new file mode 100644
index 0000000000..07ba9c6571
Binary files /dev/null and b/docs/diffusers/imgs/diffusers_library.jpg differ
diff --git a/examples/diffusers/cogvideox_factory/README.md b/examples/diffusers/cogvideox_factory/README.md
index baa030f457..5e5a031897 100644
--- a/examples/diffusers/cogvideox_factory/README.md
+++ b/examples/diffusers/cogvideox_factory/README.md
@@ -5,7 +5,8 @@
> 我们的开发和验证基于Ascend Atlas 800T A2硬件,相关环境如下:
> | mindspore | ascend driver | firmware | cann toolkit/kernel |
> |:----------:|:--------------:|:-----------:|:------------------:|
-> | 2.5 | 24.1.RC2 | 7.5.0.1.129 | 8.0.0.beta1 |
+> | 2.6.0 | 24.1.RC2 | 7.3.0.1.231 | 8.1.RC1 |
+> | 2.7.0 | 24.1.RC2 | 7.3.0.1.231 | 8.2.RC1 |
@@ -409,3 +410,7 @@ NODE_RANK="0"
当前训练脚本并不完全支持原仓代码的所有训练参数,详情参见[`args.py`](./scripts/args.py)中的`check_args()`。
其中一个主要的限制来自于CogVideoX模型中的[3D Causual VAE不支持静态图](https://gist.github.com/townwish4git/b6cd0d213b396eaedfb69b3abcd742da),这导致我们**不支持静态图模式下VAE参与训练**,因此在静态图模式下必须提前进行数据预处理以获取VAE-latents/text-encoder-embeddings cache。
+
+
+### 注意
+训练结束后若出现 `Exception ignored: OSError [Errno 9] Bad file descriptor`,仅为 Python 关闭时的提示,不影响训练结果;使用 Python 3.11,该提示不再出现。
diff --git a/examples/diffusers/cogvideox_factory/cogvideox/models/autoencoder_kl_cogvideox_sp.py b/examples/diffusers/cogvideox_factory/cogvideox/models/autoencoder_kl_cogvideox_sp.py
index c802b91eb8..acb0b98f73 100644
--- a/examples/diffusers/cogvideox_factory/cogvideox/models/autoencoder_kl_cogvideox_sp.py
+++ b/examples/diffusers/cogvideox_factory/cogvideox/models/autoencoder_kl_cogvideox_sp.py
@@ -31,7 +31,6 @@
from mindone.diffusers.models.layers_compat import pad
from mindone.diffusers.models.modeling_outputs import AutoencoderKLOutput
from mindone.diffusers.models.modeling_utils import ModelMixin
-from mindone.diffusers.models.normalization import GroupNorm
from mindone.diffusers.models.upsampling import CogVideoXUpsample3D
from mindone.diffusers.utils import logging
@@ -40,7 +39,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class GroupNorm_SP(GroupNorm):
+class GroupNorm_SP(mint.nn.GroupNorm):
def set_frame_group_size(self, frame_group_size):
self.frame_group_size = frame_group_size
diff --git a/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_lora.sh b/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_lora.sh
index f8f4069fd8..deebc39706 100644
--- a/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_lora.sh
+++ b/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_lora.sh
@@ -32,7 +32,7 @@ AMP_LEVEL=O2
DATA_ROOT="preprocessed-dataset"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
-MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5b"
+MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5B"
H=768
W=1360
F=77
diff --git a/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_sft.sh b/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_sft.sh
index af86ded8cc..7935dc8b50 100644
--- a/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_sft.sh
+++ b/examples/diffusers/cogvideox_factory/scripts/train_text_to_video_sft.sh
@@ -40,7 +40,7 @@ DEEPSPEED_ZERO_STAGE=3
DATA_ROOT="preprocessed-dataset"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
-MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5b"
+MODEL_NAME_OR_PATH="THUDM/CogVideoX1.5-5B"
H=768
W=1360
F=77
diff --git a/examples/diffusers/cogview/README.md b/examples/diffusers/cogview/README.md
index 9d0a3f416f..90c7bc2b95 100644
--- a/examples/diffusers/cogview/README.md
+++ b/examples/diffusers/cogview/README.md
@@ -29,7 +29,7 @@ cd mindone
pip install -e .
# NOTE: transformers requires >=4.46.0
-cd examples/cogview
+cd examples/diffusers/cogview
```
diff --git a/examples/diffusers/controlnet/test_controlnet.py b/examples/diffusers/controlnet/test_controlnet.py
new file mode 100644
index 0000000000..9fab5b3327
--- /dev/null
+++ b/examples/diffusers/controlnet/test_controlnet.py
@@ -0,0 +1,146 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class ControlNet(ExamplesTests):
+ def test_controlnet_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=6
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
+ )
+
+ resume_run_args = f"""
+ examples/diffusers/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-6
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+
+class ControlNetSDXL(ExamplesTests):
+ def test_controlnet_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/controlnet/train_controlnet_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+
+class ControlNetflux(ExamplesTests):
+ def test_controlnet_flux(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/controlnet/train_controlnet_flux.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
+ --output_dir={tmpdir}
+ --dataset_name=hf-internal-testing/fill10
+ --conditioning_image_column=conditioning_image
+ --image_column=image
+ --caption_column=text
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ --num_double_layers=1
+ --num_single_layers=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
diff --git a/examples/diffusers/controlnet/train_controlnet.py b/examples/diffusers/controlnet/train_controlnet.py
index fd6be61215..ea18d066b1 100644
--- a/examples/diffusers/controlnet/train_controlnet.py
+++ b/examples/diffusers/controlnet/train_controlnet.py
@@ -879,8 +879,8 @@ def __len__(self):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -939,8 +939,7 @@ def __len__(self):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
diff --git a/examples/diffusers/controlnet/train_controlnet_flux.py b/examples/diffusers/controlnet/train_controlnet_flux.py
index 7d6c01d466..ab6988ea48 100644
--- a/examples/diffusers/controlnet/train_controlnet_flux.py
+++ b/examples/diffusers/controlnet/train_controlnet_flux.py
@@ -35,7 +35,7 @@
from mindspore.dataset import GeneratorDataset, transforms, vision
from mindone.diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
-from mindone.diffusers.models.controlnet_flux import FluxControlNetModel
+from mindone.diffusers.models.controlnets.controlnet_flux import FluxControlNetModel
from mindone.diffusers.models.layers_compat import set_amp_strategy
from mindone.diffusers.optimization import get_scheduler
from mindone.diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
diff --git a/examples/diffusers/controlnet/train_controlnet_sdxl.py b/examples/diffusers/controlnet/train_controlnet_sdxl.py
index 572d751bd0..32aca96cb3 100644
--- a/examples/diffusers/controlnet/train_controlnet_sdxl.py
+++ b/examples/diffusers/controlnet/train_controlnet_sdxl.py
@@ -990,8 +990,8 @@ def __len__(self):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -1050,8 +1050,7 @@ def __len__(self):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
diff --git a/examples/diffusers/dreambooth/test_dreambooth.py b/examples/diffusers/dreambooth/test_dreambooth.py
new file mode 100644
index 0000000000..7b033490bf
--- /dev/null
+++ b/examples/diffusers/dreambooth/test_dreambooth.py
@@ -0,0 +1,236 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from mindone.diffusers import DiffusionPipeline, UNet2DConditionModel
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBooth(ExamplesTests):
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_if(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --revision refs/pr/1
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(
+ pretrained_model_name_or_path, unet=unet, safety_checker=None, revision="refs/pr/4"
+ )
+ pipe(instance_prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=1)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --instance_data_dir=docs/diffusers/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --instance_data_dir=docs/diffusers/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ resume_run_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --instance_data_dir=docs/diffusers/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/diffusers/dreambooth/test_dreambooth_lora.py b/examples/diffusers/dreambooth/test_dreambooth_lora.py
new file mode 100644
index 0000000000..271883484b
--- /dev/null
+++ b/examples/diffusers/dreambooth/test_dreambooth_lora.py
@@ -0,0 +1,360 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import mindspore as ms
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+from mindone.diffusers import DiffusionPipeline # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRA(ExamplesTests):
+ def test_dreambooth_lora(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --train_text_encoder
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # check `text_encoder` is present at all.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ keys = lora_state_dict.keys()
+ is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_text_encoder_present)
+
+ # the names of the keys of the state dict should either start with `unet`
+ # or `text_encoder`.
+ is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_correct_naming)
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --instance_data_dir=docs/diffusers/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_if_model(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --revision refs/pr/1
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+
+class DreamBoothLoRASDXL(ExamplesTests):
+ def test_dreambooth_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --revision refs/pr/2
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path, revision="refs/pr/2")
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --revision refs/pr/2
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --train_text_encoder
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path, revision="refs/pr/2")
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
diff --git a/examples/diffusers/dreambooth/test_dreambooth_lora_edm.py b/examples/diffusers/dreambooth/test_dreambooth_lora_edm.py
new file mode 100644
index 0000000000..2e77ed9b28
--- /dev/null
+++ b/examples/diffusers/dreambooth/test_dreambooth_lora_edm.py
@@ -0,0 +1,104 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import mindspore as ms
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASDXLWithEDM(ExamplesTests):
+ def test_dreambooth_lora_sdxl_with_edm(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --do_edm_style_training
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_playground(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
+ --instance_data_dir docs/diffusers/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
diff --git a/examples/diffusers/dreambooth/test_dreambooth_lora_flux.py b/examples/diffusers/dreambooth/test_dreambooth_lora_flux.py
new file mode 100644
index 0000000000..73db595c3b
--- /dev/null
+++ b/examples/diffusers/dreambooth/test_dreambooth_lora_flux.py
@@ -0,0 +1,132 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import mindspore as ms
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAFlux(ExamplesTests):
+ instance_data_dir = "docs/diffusers/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
+ script_path = "train_dreambooth_lora_flux.py"
+ transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
+
+ def test_dreambooth_lora_flux(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_text_encoder_flux(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ starts_with_expected_prefix = all(
+ (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_expected_prefix)
+
+ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
diff --git a/examples/diffusers/dreambooth/test_dreambooth_lora_sd3.py b/examples/diffusers/dreambooth/test_dreambooth_lora_sd3.py
new file mode 100644
index 0000000000..e5545dfeeb
--- /dev/null
+++ b/examples/diffusers/dreambooth/test_dreambooth_lora_sd3.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import mindspore as ms
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASD3(ExamplesTests):
+ instance_data_dir = "docs/diffusers/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+ script_path = "train_dreambooth_lora_sd3.py"
+
+ transformer_block_idx = 0
+ layer_type = "attn.to_k"
+
+ def test_dreambooth_lora_sd3(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --mixed_precision fp16
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --mixed_precision fp16
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
diff --git a/examples/diffusers/dreambooth/test_dreambooth_sd3.py b/examples/diffusers/dreambooth/test_dreambooth_sd3.py
new file mode 100644
index 0000000000..9edc34c896
--- /dev/null
+++ b/examples/diffusers/dreambooth/test_dreambooth_sd3.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+from mindone.diffusers import DiffusionPipeline, SD3Transformer2DModel
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothSD3(ExamplesTests):
+ instance_data_dir = "docs/diffusers/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+ script_path = "train_dreambooth_sd3.py"
+
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ transformer = SD3Transformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
+ pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/dreambooth/{self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
diff --git a/examples/diffusers/dreambooth/train_dreambooth.py b/examples/diffusers/dreambooth/train_dreambooth.py
index e50e9bf775..333ccbda6b 100644
--- a/examples/diffusers/dreambooth/train_dreambooth.py
+++ b/examples/diffusers/dreambooth/train_dreambooth.py
@@ -1057,8 +1057,8 @@ def compute_text_embeddings(prompt):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -1119,8 +1119,7 @@ def compute_text_embeddings(prompt):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
diff --git a/examples/diffusers/test_examples_utils.py b/examples/diffusers/test_examples_utils.py
new file mode 100644
index 0000000000..034f5f9953
--- /dev/null
+++ b/examples/diffusers/test_examples_utils.py
@@ -0,0 +1,62 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import subprocess
+import tempfile
+import unittest
+from typing import List
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: List[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
+ if an error occurred while running `command`
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
+
+
+class ExamplesTests(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._tmpdir = tempfile.mkdtemp()
+ cls._launch_args = ["python"]
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ shutil.rmtree(cls._tmpdir)
+
+ def run_example(self, script_path: str, args: List[str] = None, return_stdout: bool = False):
+ """Run a Python example script directly."""
+ if args is None:
+ args = []
+ command = self._launch_args + [script_path] + args
+ return run_command(command, return_stdout=return_stdout)
diff --git a/examples/diffusers/text_to_image/test_text_to_image.py b/examples/diffusers/text_to_image/test_text_to_image.py
new file mode 100644
index 0000000000..e8a76af0dd
--- /dev/null
+++ b/examples/diffusers/text_to_image/test_text_to_image.py
@@ -0,0 +1,290 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from mindone.diffusers import DiffusionPipeline, UNet2DConditionModel # noqa: E402
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextToImage(ExamplesTests):
+ def test_text_to_image(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_text_to_image_checkpointing(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(
+ pretrained_model_name_or_path, unet=unet, safety_checker=None, revision="refs/pr/4"
+ )
+ pipe(prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 2 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-5"},
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # resume and we should try to checkpoint at 6, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 8
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8"},
+ )
+
+
+class TextToImageSDXL(ExamplesTests):
+ def test_text_to_image_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
diff --git a/examples/diffusers/text_to_image/test_text_to_image_lora.py b/examples/diffusers/text_to_image/test_text_to_image_lora.py
new file mode 100644
index 0000000000..41fff43c39
--- /dev/null
+++ b/examples/diffusers/text_to_image/test_text_to_image_lora.py
@@ -0,0 +1,311 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import mindspore as ms
+
+from mindone.diffusers import DiffusionPipeline # noqa: E402
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextToImageLoRA(ExamplesTests):
+ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --revision refs/pr/2
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path, revision="refs/pr/2")
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, revision="refs/pr/4"
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, revision="refs/pr/4"
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # resume and we should try to checkpoint at 6, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --revision refs/pr/4
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 8
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, revision="refs/pr/4"
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8"},
+ )
+
+
+class TextToImageLoRASDXL(ExamplesTests):
+ def test_text_to_image_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ def test_text_to_image_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --revision refs/pr/2
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = ms.load_checkpoint(
+ os.path.join(tmpdir, "pytorch_lora_weights.safetensors"), format="safetensors"
+ )
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/diffusers/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --revision refs/pr/2
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --train_text_encoder
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path, revision="refs/pr/2")
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
diff --git a/examples/diffusers/text_to_image/train_text_to_image.py b/examples/diffusers/text_to_image/train_text_to_image.py
index dcbcfa556d..31a3aece5a 100644
--- a/examples/diffusers/text_to_image/train_text_to_image.py
+++ b/examples/diffusers/text_to_image/train_text_to_image.py
@@ -728,8 +728,8 @@ def __len__(self):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file))
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"))
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -798,8 +798,7 @@ def __len__(self):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.numpy().item(), "lr": optimizer.get_lr().numpy().item()}
diff --git a/examples/diffusers/text_to_image/train_text_to_image_lora.py b/examples/diffusers/text_to_image/train_text_to_image_lora.py
index 29adeeff89..687758e7a6 100644
--- a/examples/diffusers/text_to_image/train_text_to_image_lora.py
+++ b/examples/diffusers/text_to_image/train_text_to_image_lora.py
@@ -748,8 +748,8 @@ def __len__(self):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file), strict_load=True)
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"), strict_load=True)
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -817,8 +817,7 @@ def __len__(self):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionPipeline.save_lora_weights(
diff --git a/examples/diffusers/text_to_image/train_text_to_image_sdxl.py b/examples/diffusers/text_to_image/train_text_to_image_sdxl.py
index 9e772da1b7..88ae7cd6b8 100644
--- a/examples/diffusers/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/diffusers/text_to_image/train_text_to_image_sdxl.py
@@ -941,8 +941,8 @@ def __len__(self):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file))
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"))
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -1012,8 +1012,7 @@ def __len__(self):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.numpy().item(), "lr": optimizer.get_lr().numpy().item()}
diff --git a/examples/diffusers/textual_inversion/test_textual_inversion.py b/examples/diffusers/textual_inversion/test_textual_inversion.py
new file mode 100644
index 0000000000..d5ddb40930
--- /dev/null
+++ b/examples/diffusers/textual_inversion/test_textual_inversion.py
@@ -0,0 +1,156 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextualInversion(ExamplesTests):
+ def test_textual_inversion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+ def test_textual_inversion_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
+
+ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2"},
+ )
+
+ resume_run_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --revision refs/pr/4
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
diff --git a/examples/diffusers/textual_inversion/test_textual_inversion_sdxl.py b/examples/diffusers/textual_inversion/test_textual_inversion_sdxl.py
new file mode 100644
index 0000000000..4f7c5bbcfc
--- /dev/null
+++ b/examples/diffusers/textual_inversion/test_textual_inversion_sdxl.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextualInversionSdxl(ExamplesTests):
+ def test_textual_inversion_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+ def test_textual_inversion_sdxl_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
+
+ def test_textual_inversion_sdxl_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2"},
+ )
+
+ resume_run_args = f"""
+ examples/diffusers/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/diffusers/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
diff --git a/examples/diffusers/textual_inversion/textual_inversion.py b/examples/diffusers/textual_inversion/textual_inversion.py
index 6483f01487..da849f6249 100644
--- a/examples/diffusers/textual_inversion/textual_inversion.py
+++ b/examples/diffusers/textual_inversion/textual_inversion.py
@@ -789,8 +789,8 @@ def freeze_params(m: nn.Cell):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file))
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"))
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -876,8 +876,7 @@ def freeze_params(m: nn.Cell):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
diff --git a/examples/diffusers/textual_inversion/textual_inversion_sdxl.py b/examples/diffusers/textual_inversion/textual_inversion_sdxl.py
index cabb5af775..7a86a475d5 100644
--- a/examples/diffusers/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/diffusers/textual_inversion/textual_inversion_sdxl.py
@@ -784,8 +784,8 @@ def freeze_params(m: nn.Cell):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file))
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"))
global_step = int(path.split("-")[1])
initial_global_step = global_step
@@ -869,8 +869,7 @@ def freeze_params(m: nn.Cell):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
diff --git a/examples/diffusers/unconditional_image_generation/test_unconditional.py b/examples/diffusers/unconditional_image_generation/test_unconditional.py
new file mode 100644
index 0000000000..1cb1e6c8ef
--- /dev/null
+++ b/examples/diffusers/unconditional_image_generation/test_unconditional.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+sys.path.append("..")
+from examples.diffusers.test_examples_utils import ExamplesTests, run_command # noqa: E402
+
+ExamplesTests._launch_args = ["python"]
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class Unconditional(ExamplesTests):
+ def test_train_unconditional(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/diffusers/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 2
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ """.split()
+
+ run_command(self._launch_args + test_args, return_stdout=True)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_unconditional_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/diffusers/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/diffusers/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 1
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
+ )
+
+ resume_run_args = f"""
+ examples/diffusers/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 2
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 1
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --resume_from_checkpoint=checkpoint-6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-10", "checkpoint-12"},
+ )
diff --git a/examples/diffusers/unconditional_image_generation/train_unconditional.py b/examples/diffusers/unconditional_image_generation/train_unconditional.py
index b4888d6970..aa0d7e7576 100644
--- a/examples/diffusers/unconditional_image_generation/train_unconditional.py
+++ b/examples/diffusers/unconditional_image_generation/train_unconditional.py
@@ -460,8 +460,8 @@ def __len__(self):
if is_master(args):
logger.info(f"Resuming from checkpoint {path}")
# TODO: load optimizer & grad scaler etc. like accelerator.load_state
- input_model_file = os.path.join(args.output_dir, path, "pytorch_model.ckpt")
- ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file))
+ input_model_file = os.path.join(args.output_dir, path, "unet/diffusion_pytorch_model.safetensors")
+ ms.load_param_into_net(unet, ms.load_checkpoint(input_model_file, format="safetensors"))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
@@ -513,8 +513,7 @@ def __len__(self):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# TODO: save optimizer & grad scaler etc. like accelerator.save_state
os.makedirs(save_path, exist_ok=True)
- output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
- ms.save_checkpoint(unet, output_model_file)
+ unet.save_pretrained(os.path.join(save_path, "unet"))
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.numpy().item(), "lr": optimizer.get_lr().numpy().item(), "step": global_step}
diff --git a/mindone/diffusers/models/transformers/transformer_flux.py b/mindone/diffusers/models/transformers/transformer_flux.py
index 17dda2b34e..5151c5afe5 100644
--- a/mindone/diffusers/models/transformers/transformer_flux.py
+++ b/mindone/diffusers/models/transformers/transformer_flux.py
@@ -112,11 +112,12 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
+ # mindspore not support split_with_sizes and mindspore 2.6 jit not support the format of split_size as list
# encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
# [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
# )
encoder_hidden_states, hidden_states = hidden_states.split(
- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ (encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]), dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
@@ -204,11 +205,12 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
+ # mindspore not support split_with_sizes and mindspore 2.6 jit not support the format of split_size as list
# encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
# [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
# )
encoder_hidden_states, hidden_states = hidden_states.split(
- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ (encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]), dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py
index 2141e8d134..ca18c93474 100644
--- a/mindone/transformers/modeling_utils.py
+++ b/mindone/transformers/modeling_utils.py
@@ -1664,6 +1664,10 @@ def save_pretrained(
# we currently don't use this setting automatically, but may start to use with v5
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.torch_dtype = repr(dtype).split(".")[1]
+ model_to_save.config.mindspore_dtype = repr(dtype).split(".")[1]
+ for sub in ("text_config", "vision_config"):
+ if hasattr(model_to_save.config, sub):
+ getattr(model_to_save.config, sub).mindspore_dtype = repr(dtype).split(".")[1]
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]