
Commit 18234b8

Merge remote-tracking branch 'upstream/master' into test_diffusers_example
2 parents: e8afa2a + b896582

File tree

166 files changed: +8792 −762 lines


docs/diffusers/imgs/README.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+### Image Credits
+
+The images in this folder are taken from the [Hugging Face Diffusers repository](https://github.com/huggingface/diffusers/tree/main/docs/source/en/imgs) and are subject to the Apache 2.0 license of the Diffusers project.

examples/diffusers/cogvideox_factory/README.md

Lines changed: 4 additions & 0 deletions
@@ -410,3 +410,7 @@ NODE_RANK="0"
 The current training script does not fully support every training argument of the original repository; see `check_args()` in [`args.py`](./scripts/args.py) for details.
 
 One major limitation comes from the [3D Causal VAE in CogVideoX, which does not support graph mode](https://gist.github.com/townwish4git/b6cd0d213b396eaedfb69b3abcd742da). As a result, **training the VAE in graph mode is not supported**, so in graph mode the data must be preprocessed in advance to produce the VAE-latents/text-encoder-embeddings cache.
+
+
+### Note
+If `Exception ignored: OSError [Errno 9] Bad file descriptor` appears after training completes, it is only a message emitted while Python shuts down and does not affect the training results; it no longer appears after upgrading to Python 3.11.
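The note above means graph-mode training depends on an offline caching pass. Below is a minimal sketch of such preprocessing, assuming PyNative mode and hypothetical `vae`/`text_encoder` objects; the repository's actual caching entry points live under `scripts/` and may differ.

```python
# Minimal offline-caching sketch (hypothetical helper, not the repo's script).
# Runs the VAE and text encoder once per sample in PyNative mode, saving the
# tensors that graph-mode training later consumes directly.
import os

import numpy as np
import mindspore as ms

ms.set_context(mode=ms.PYNATIVE_MODE)  # the 3D causal VAE cannot run in graph mode


def cache_sample(vae, text_encoder, video, token_ids, out_dir, name):
    latents = vae.encode(video)  # hypothetical VAE wrapper returning latents
    text_emb = text_encoder(token_ids)  # hypothetical text-encoder forward
    np.savez(
        os.path.join(out_dir, f"{name}.npz"),
        latents=latents.asnumpy(),
        text_emb=text_emb.asnumpy(),
    )
```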

examples/mmada/models/modeling_utils.py

Lines changed: 363 additions & 122 deletions
Large diffs are not rendered by default.

examples/mmada/training/train_mmada_stage2.py

Lines changed: 1 addition & 1 deletion
@@ -704,7 +704,7 @@ def prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu, e
 
     # Evaluate and save checkpoint at the end of training
     if rank_id == 0:
-        save_checkpoint(model, config, global_step)
+        save_checkpoint(model, config, global_step, uni_prompting)
 
 
 def visualize_predictions(

examples/opensora_pku/README.md

Lines changed: 0 additions & 4 deletions
@@ -151,10 +151,6 @@ python tools/model_conversion/convert_wfvae.py --src LanguageBind/Open-Sora-Plan
 python tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py --src google/mt5-xxl/pytorch_model.bin --target google/mt5-xxl/model.safetensors --config google/mt5-xxl/config.json
 ```
 
-In addition, please merge the multiple .safetensors files under `any93x640x640/` into a merged checkpoint:
-```shell
-python tools/ckpt/merge_safetensors.py -i LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640/ -o LanguageBind/Open-Sora-Plan-v1.3.0/diffusion_pytorch_model.safetensors -f LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640/diffusion_pytorch_model.safetensors.index.json
-```
 
 Once the checkpoint files have all been prepared, you can refer to the inference guidance below.

examples/opensora_pku/opensora/dataset/t2v_datasets.py

Lines changed: 4 additions & 1 deletion
@@ -15,7 +15,6 @@
 
 import av
 import cv2
-import decord
 import numpy as np
 from opensora.dataset.transform import (
     add_aesthetic_notice_image,
@@ -101,6 +100,10 @@ def get_item(self, work_info):
 
 class DecordDecoder(object):
     def __init__(self, url, num_threads=1):
+        try:
+            import decord
+        except ImportError:
+            raise ImportError("Please install decord!")
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)
        self.reader = decord.VideoReader(url, ctx=self.ctx, num_threads=self.num_threads)
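This hunk moves the `decord` import from module level into `DecordDecoder.__init__`, so the dataset module stays importable when the package is absent. Below is a generic sketch of the same deferred-import pattern, with illustrative names not taken from the repository.

```python
# Deferred-import helper (illustrative, not part of the repo): resolve an
# optional dependency only at the point of use, with a friendly install hint.
import importlib


def require(module_name: str, hint: str):
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(hint) from exc


# e.g. inside a method that actually decodes video:
# decord = require("decord", "Please install decord!")
```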

examples/opensora_pku/opensora/models/causalvideovae/model/modeling_videobase.py

Lines changed: 255 additions & 1 deletion
@@ -1,10 +1,30 @@
 # Adapted from
 # https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/opensora/models/causalvideovae/model/modeling_videobase.py
 
+import copy
+import logging
+import os
+from typing import Dict, Optional, Union
+
+from huggingface_hub import DDUFEntry
+from huggingface_hub.utils import validate_hf_hub_args
+
 import mindspore as ms
+from mindspore.nn.utils import no_init_parameters
 
-from mindone.diffusers import ModelMixin
+from mindone.diffusers import ModelMixin, __version__
 from mindone.diffusers.configuration_utils import ConfigMixin
+from mindone.diffusers.models.model_loading_utils import _fetch_index_file, _fetch_index_file_legacy, load_state_dict
+from mindone.diffusers.models.modeling_utils import _convert_state_dict
+from mindone.diffusers.utils import (
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    _add_variant,
+    _get_checkpoint_shard_files,
+    _get_model_file,
+)
+
+logger = logging.getLogger(__name__)
 
 
 class VideoBaseAE(ModelMixin, ConfigMixin):
@@ -23,3 +43,237 @@ def encode(self, x: ms.Tensor, *args, **kwargs):
 
     def decode(self, encoding: ms.Tensor, *args, **kwargs):
         pass
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        # adapted from mindone.diffusers.models.modeling_utils.from_pretrained
+        state_dict = kwargs.pop("state_dict", None)  # additional keyword argument
+        cache_dir = kwargs.pop("cache_dir", None)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        from_flax = kwargs.pop("from_flax", False)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        mindspore_dtype = kwargs.pop("mindspore_dtype", None)
+        subfolder = kwargs.pop("subfolder", None)
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
+
+        if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
+            logger.warning(f"Passed `mindspore_dtype` {mindspore_dtype} is not a `ms.Type`. Defaulting to `ms.float32`.")
+            mindspore_dtype = ms.float32
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+        unused_kwargs = {}
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            dduf_entries=dduf_entries,
+            **kwargs,
+        )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)
+
+        # Check if `_keep_in_fp32_modules` is not None
+        # use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
+        #     hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+        # )
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (mindspore_dtype == ms.float16)
+
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+        else:
+            keep_in_fp32_modules = []
+
+        is_sharded = False
+        resolved_model_file = None
+
+        # Determine if we're loading from a directory of sharded checkpoints.
+        sharded_metadata = None
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+            "dduf_entries": dduf_entries,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # This becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
+        if index_file is not None and (dduf_entries or index_file.is_file()):
+            is_sharded = True
+
+        # load model
+        if from_flax:
+            raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.")
+        else:
+            # in the case it is sharded, we have already the index
+            if is_sharded:
+                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
+                )
+            elif use_safetensors:
+                try:
+                    resolved_model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                        commit_hash=commit_hash,
+                        dduf_entries=dduf_entries,
+                    )
+
+                except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                    if not allow_pickle:
+                        raise
+                    logger.warning("Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.")
+
+            if resolved_model_file is None and not is_sharded:
+                resolved_model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=_add_variant(WEIGHTS_NAME, variant),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
+                )
+
+        if not isinstance(resolved_model_file, list):
+            resolved_model_file = [resolved_model_file]
+
+        # set dtype to instantiate the model under:
+        # 1. If mindspore_dtype is not None, we use that dtype
+        # 2. If mindspore_dtype is float8, we don't use _set_default_mindspore_dtype and we downcast after loading the model
+        dtype_orig = None  # noqa
+        if mindspore_dtype is not None:
+            if not isinstance(mindspore_dtype, ms.Type):
+                raise ValueError(
+                    f"{mindspore_dtype} needs to be of type `mindspore.Type`, e.g. `mindspore.float16`, but is {type(mindspore_dtype)}."
+                )
+
+        with no_init_parameters():
+            model = cls.from_config(config, **unused_kwargs)
+
+        # state_dict may be passed in as an additional keyword argument
+        if state_dict is None:  # edits: only load model_file if state_dict is None
+            if not is_sharded:
+                # Time to load the checkpoint
+                state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+                # We only fix it for non-sharded checkpoints, as we don't need it yet for sharded ones.
+                model._fix_state_dict_keys_on_load(state_dict)
+
+        if is_sharded:
+            loaded_keys = sharded_metadata["all_checkpoint_keys"]
+        else:
+            state_dict = _convert_state_dict(model, state_dict)
+            loaded_keys = list(state_dict.keys())
+
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = cls._load_pretrained_model(
+            model,
+            state_dict,
+            resolved_model_file,
+            pretrained_model_name_or_path,
+            loaded_keys,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            dtype=mindspore_dtype,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            dduf_entries=dduf_entries,
+        )
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+            "error_msgs": error_msgs,
+        }
+
+        if mindspore_dtype is not None and not use_keep_in_fp32_modules:
+            model = model.to(mindspore_dtype)
+
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.set_train(False)
+
+        if output_loading_info:
+            return model, loading_info
+
+        return model
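For orientation, here is a hedged usage sketch of the overridden `from_pretrained`; the subclass, paths, and `preloaded_sd` are placeholders, and only the `state_dict` keyword is specific to this commit.

```python
# Hypothetical usage (placeholder subclass and paths; not from the repo).
import mindspore as ms

from opensora.models.causalvideovae.model.modeling_videobase import VideoBaseAE


class ToyVAE(VideoBaseAE):
    pass  # a real subclass would register a config and define encode/decode


# Standard path: the config and (possibly sharded) safetensors weights are
# resolved from the local directory or hub repo and loaded into the model.
vae = ToyVAE.from_pretrained("path/to/vae_checkpoint", mindspore_dtype=ms.float16)

# New in this commit: pass an already-loaded state_dict so the resolved
# checkpoint file is not read again; config resolution still happens.
# vae = ToyVAE.from_pretrained("path/to/vae_checkpoint", state_dict=preloaded_sd)
```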
