# Adapted from
# https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/opensora/models/causalvideovae/model/modeling_videobase.py

import copy
import logging
import os
from typing import Dict, Optional, Union

from huggingface_hub import DDUFEntry
from huggingface_hub.utils import validate_hf_hub_args

import mindspore as ms
from mindspore.nn.utils import no_init_parameters

from mindone.diffusers import ModelMixin, __version__
from mindone.diffusers.configuration_utils import ConfigMixin
from mindone.diffusers.models.model_loading_utils import _fetch_index_file, _fetch_index_file_legacy, load_state_dict
from mindone.diffusers.models.modeling_utils import _convert_state_dict
from mindone.diffusers.utils import (
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    _add_variant,
    _get_checkpoint_shard_files,
    _get_model_file,
)

logger = logging.getLogger(__name__)


class VideoBaseAE(ModelMixin, ConfigMixin):
    # ... (class docstring and `def encode(self, x: ms.Tensor, *args, **kwargs)` elided) ...

    def decode(self, encoding: ms.Tensor, *args, **kwargs):
        pass

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        # adapted from mindone.diffusers.models.modeling_utils.from_pretrained
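        """Instantiate a pretrained model from a local path or the Hugging Face Hub.

        In addition to the keyword arguments handled below (e.g. `cache_dir`, `variant`,
        `use_safetensors`, `mindspore_dtype`), this override accepts an extra `state_dict`
        keyword argument: a pre-loaded state dict that is used directly instead of reading
        weights from the resolved checkpoint file.
        """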
        state_dict = kwargs.pop("state_dict", None)  # additional keyword argument
        cache_dir = kwargs.pop("cache_dir", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        from_flax = kwargs.pop("from_flax", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        mindspore_dtype = kwargs.pop("mindspore_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

        if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
            logger.warning(
                f"Passed `mindspore_dtype` {mindspore_dtype} is not a `ms.Type`. Defaulting to `ms.float32`."
            )
            mindspore_dtype = ms.float32

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }
        unused_kwargs = {}

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            dduf_entries=dduf_entries,
            **kwargs,
        )
        # no in-place modification of the original config.
        config = copy.deepcopy(config)

        # Check if `_keep_in_fp32_modules` is not None
        # use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
        #     hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
        # )
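        # Note: there is no quantizer path in this port, so modules listed in
        # `_keep_in_fp32_modules` are preserved in float32 only when loading in float16.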
        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (mindspore_dtype == ms.float16)

        if use_keep_in_fp32_modules:
            keep_in_fp32_modules = cls._keep_in_fp32_modules
            if not isinstance(keep_in_fp32_modules, list):
                keep_in_fp32_modules = [keep_in_fp32_modules]
        else:
            keep_in_fp32_modules = []

        is_sharded = False
        resolved_model_file = None

        # Determine if we're loading from a directory of sharded checkpoints.
        sharded_metadata = None
        index_file = None
        is_local = os.path.isdir(pretrained_model_name_or_path)
        index_file_kwargs = {
            "is_local": is_local,
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
            "subfolder": subfolder or "",
            "use_safetensors": use_safetensors,
            "cache_dir": cache_dir,
            "variant": variant,
            "force_download": force_download,
            "proxies": proxies,
            "local_files_only": local_files_only,
            "token": token,
            "revision": revision,
            "user_agent": user_agent,
            "commit_hash": commit_hash,
            "dduf_entries": dduf_entries,
        }
        index_file = _fetch_index_file(**index_file_kwargs)
        # In case the index file was not found we still have to consider the legacy format.
        # this becomes applicable when the variant is not None.
        if variant is not None and (index_file is None or not os.path.exists(index_file)):
            index_file = _fetch_index_file_legacy(**index_file_kwargs)
        if index_file is not None and (dduf_entries or index_file.is_file()):
            is_sharded = True

        # load model
        if from_flax:
            raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.")
        else:
            # in the case it is sharded, we have already the index
            if is_sharded:
                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
                    pretrained_model_name_or_path,
                    index_file,
                    cache_dir=cache_dir,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder or "",
                    dduf_entries=dduf_entries,
                )
            elif use_safetensors:
                try:
                    resolved_model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                        dduf_entries=dduf_entries,
                    )

                except IOError as e:
                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                    if not allow_pickle:
                        raise
                    logger.warning(
                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                    )

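            # If safetensors were unavailable (or disabled), fall back to the regular weights file.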
            if resolved_model_file is None and not is_sharded:
                resolved_model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                    dduf_entries=dduf_entries,
                )

        if not isinstance(resolved_model_file, list):
            resolved_model_file = [resolved_model_file]

        # set dtype to instantiate the model under:
        # 1. If mindspore_dtype is not None, we use that dtype
        # 2. If mindspore_dtype is float8, we don't use _set_default_mindspore_dtype and we downcast after loading the model
        dtype_orig = None  # noqa
        if mindspore_dtype is not None:
            if not isinstance(mindspore_dtype, ms.Type):
                raise ValueError(
                    f"{mindspore_dtype} needs to be of type `mindspore.Type`, e.g. `mindspore.float16`, but is {type(mindspore_dtype)}."
                )

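        # Build the model skeleton without initializing parameters; the real weights are
        # loaded below (or taken from the provided `state_dict`).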
        with no_init_parameters():
            model = cls.from_config(config, **unused_kwargs)

        # `state_dict` may be passed directly as an additional keyword argument;
        # only read the resolved model file when it is None.
        if state_dict is None:
            if not is_sharded:
                # Time to load the checkpoint
                state_dict = load_state_dict(
                    resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
                )
                # We only fix it for non sharded checkpoints as we don't need it yet for sharded ones.
                model._fix_state_dict_keys_on_load(state_dict)

        if is_sharded:
            loaded_keys = sharded_metadata["all_checkpoint_keys"]
        else:
            state_dict = _convert_state_dict(model, state_dict)
            loaded_keys = list(state_dict.keys())

        (
            model,
            missing_keys,
            unexpected_keys,
            mismatched_keys,
            offload_index,
            error_msgs,
        ) = cls._load_pretrained_model(
            model,
            state_dict,
            resolved_model_file,
            pretrained_model_name_or_path,
            loaded_keys,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            dtype=mindspore_dtype,
            keep_in_fp32_modules=keep_in_fp32_modules,
            dduf_entries=dduf_entries,
        )
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
            "error_msgs": error_msgs,
        }

        if mindspore_dtype is not None and not use_keep_in_fp32_modules:
            model = model.to(mindspore_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.set_train(False)

        if output_loading_info:
            return model, loading_info

        return model
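
# Usage sketch (illustrative only, not part of the module): a subclass built on
# `VideoBaseAE` could be loaded either from a checkpoint directory or from an
# externally prepared state dict. `MyCausalVAE` and the paths are placeholders.
#
#   vae = MyCausalVAE.from_pretrained("/path/to/vae", mindspore_dtype=ms.float16)
#
#   sd = load_state_dict("/path/to/merged.safetensors")
#   vae = MyCausalVAE.from_pretrained("/path/to/vae", state_dict=sd)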