Commit f5ea810

a new method
1 parent 3a1ed35 commit f5ea810

6 files changed: +34 additions, -78 deletions


mindone/diffusers/models/model_loading_utils.py

Lines changed: 5 additions & 5 deletions
@@ -35,7 +35,6 @@
 
 import mindspore as ms
 from mindspore import nn, ops
-from mindspore.ops import Cast
 
 from ...safetensors.mindspore import load as safe_load
 from ..utils import (
@@ -51,7 +50,8 @@
 )
 
 logger = logging.get_logger(__name__)
-cpu_cast = Cast().set_device("CPU")
+ms.Parameter._data = ms.Tensor.data
+ms.Parameter.data_ptr = ms.Tensor.data_ptr
 
 _CLASS_REMAPPING_DICT = {
     "Transformer2DModel": {
@@ -146,11 +146,11 @@ def _load_state_dict_into_model(
                 if keep_in_fp32_modules is not None and any(
                     module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
                 ):
-                    state_dict[k] = ms.Parameter(cpu_cast(v.data, ms.float32), name=k)
+                    v._data = v.to(device="CPU", dtype=ms.float32)
                 else:
-                    state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
+                    v._data = v.to(device="CPU", dtype=local_state[k].dtype)
             else:
-                state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
+                v._data = v.to(device="CPU", dtype=local_state[k].dtype)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
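
For orientation: the old loader wrapped every checkpoint value in a fresh ms.Parameter built from a CPU Cast op, while the new code rebinds the existing tensor's storage through the Parameter._data / Tensor.data alias patched in at module level. Below is a minimal sketch of the new in-place cast, assuming the MindSpore build this commit targets (one where Tensor.to accepts device/dtype keywords and Tensor.data is assignable); the state dict and target dtype are illustrative only.

import mindspore as ms
from mindspore import ops

# Same module-level patch as in the diff: let Parameter reuse Tensor's data accessors.
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr

# Hypothetical checkpoint entry and target dtype, for illustration.
state_dict = {"conv.weight": ms.Parameter(ops.ones((2, 2), ms.float32), name="conv.weight")}
target_dtype = ms.float16

for k, v in state_dict.items():
    # New approach: cast on CPU and swap the parameter's data in place,
    # instead of allocating a new ms.Parameter via cpu_cast.
    v._data = v.to(device="CPU", dtype=target_dtype)

print(state_dict["conv.weight"].dtype)  # expected: Float16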

mindone/diffusers/models/modeling_utils.py

Lines changed: 11 additions & 8 deletions
@@ -37,6 +37,7 @@
 from mindspore.nn.utils import no_init_parameters
 
 from mindone.safetensors.mindspore import save_file as safe_save_file
+from mindone.utils.modeling_patch import patch_nn_default_dtype, unpatch_nn_default_dtype
 
 from .. import __version__
 from ..utils import (
@@ -61,7 +62,9 @@
     load_state_dict,
     split_torch_state_dict_into_shards,
 )
-from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
+
+ms.Parameter._data = ms.Tensor.data
+ms.Parameter.data_ptr = ms.Tensor.data_ptr
 
 
 class ContextManagers:
@@ -853,12 +856,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 f"{mindspore_dtype} needs to be of type `mindspore.Type`, e.g. `mindspore.float16`, but is {type(mindspore_dtype)}."
             )
 
+        if mindspore_dtype is not None:
+            patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
         with no_init_parameters():
-            if mindspore_dtype is not None:
-                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
             model = cls.from_config(config, **unused_kwargs)
-            if mindspore_dtype is not None:
-                restore_nn_default_dtype()
+        if mindspore_dtype is not None:
+            unpatch_nn_default_dtype()
 
         state_dict = None
         if not is_sharded:
@@ -915,17 +918,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
     def to(self, dtype: Optional[ms.Type] = None):
         for p in self.get_parameters():
             if p.dtype != dtype:
-                p.set_dtype(dtype)
+                p._data = p.to(device="CPU", dtype=dtype)
         return self
 
     def half(self):
         for p in self.get_parameters():
-            p.set_dtype(ms.float16)
+            p._data = p.to(device="CPU", dtype=ms.float16)
         return self
 
     def float(self):
         for p in self.get_parameters():
-            p.set_dtype(ms.float32)
+            p._data = p.to(device="CPU", dtype=ms.float32)
         return self
 
     def compile_repeated_blocks(self, *args, **kwargs):
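
The from_pretrained change above lifts the dtype patch out of the no_init_parameters() block, and the dtype helpers now rebind parameter data on CPU instead of calling set_dtype. A hedged usage sketch of how this path is exercised; the checkpoint id is illustrative, and mindspore_dtype is assumed to be the keyword the surrounding from_pretrained code reads.

import mindspore as ms
from mindone.diffusers import AutoencoderKL

# Weights are constructed directly in fp16 while patch_nn_default_dtype is active,
# then the patch is removed once from_config returns.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",  # illustrative checkpoint id
    mindspore_dtype=ms.float16,
)

# The reworked helpers rebind each parameter's data on CPU:
vae = vae.float()          # cast everything back to fp32
vae = vae.to(ms.bfloat16)  # or to any other target dtype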

mindone/transformers/modeling_patch.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

mindone/transformers/modeling_utils.py

Lines changed: 16 additions & 15 deletions
@@ -62,7 +62,8 @@
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.nn import CrossEntropyLoss, Identity
 from mindspore.nn.utils import no_init_parameters
-from mindspore.ops import Cast
+
+from mindone.utils.modeling_patch import patch_nn_default_dtype, unpatch_nn_default_dtype
 
 from .activations import get_activation
 from .generation.utils import GenerationMixin
@@ -81,7 +82,6 @@
     prune_linear_layer,
 )
 from .modeling_attn_mask_utils import dtype_to_min
-from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
 from .utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder
 from .utils.import_utils import is_sdpa_available
 
@@ -113,7 +113,8 @@
 ]
 
 logger = logging.get_logger(__name__)
-cpu_cast = Cast().set_device("CPU")
+ms.Parameter._data = ms.Tensor.data
+ms.Parameter.data_ptr = ms.Tensor.data_ptr
 
 _init_weights = True
 
@@ -377,7 +378,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
+            v._data = v.to(device="CPU", dtype=local_state[k].dtype)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
@@ -514,17 +515,17 @@ def _get_name(self):
     def to(self, dtype: Optional[ms.Type] = None):
        for p in self.get_parameters():
             if p.dtype != dtype:
-                p.set_dtype(dtype)
+                p._data = p.to(device="CPU", dtype=dtype)
         return self
 
     def float(self):
         for p in self.get_parameters():
-            p.set_dtype(ms.float32)
+            p._data = p.to(device="CPU", dtype=ms.float32)
         return self
 
     def half(self):
         for p in self.get_parameters():
-            p.set_dtype(ms.float16)
+            p._data = p.to(device="CPU", dtype=ms.float16)
         return self
 
     @property
@@ -1162,12 +1163,12 @@ def _from_config(cls, config, **kwargs):
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")
 
+        if mindspore_dtype is not None:
+            patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
         with no_init_parameters():
-            if mindspore_dtype is not None:
-                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
             model = cls(config, **kwargs)
-            if mindspore_dtype is not None:
-                restore_nn_default_dtype()
+        if mindspore_dtype is not None:
+            unpatch_nn_default_dtype()
 
         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
@@ -2763,12 +2764,12 @@ def from_pretrained(
 
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
 
+        if mindspore_dtype is not None:
+            patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
         with no_init_parameters():
-            if mindspore_dtype is not None:
-                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
             model = cls(config, *model_args, **model_kwargs)
-            if mindspore_dtype is not None:
-                restore_nn_default_dtype()
+        if mindspore_dtype is not None:
+            unpatch_nn_default_dtype()
 
         # Make sure to tie the weights correctly
         model.tie_weights()
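
The same restructuring applies on the transformers side: the dtype patch now wraps model construction from outside no_init_parameters(), and state-dict loading casts checkpoint tensors in place. A brief usage sketch under the same assumptions; the model name is illustrative, and AutoModel is assumed to be re-exported by mindone.transformers as in upstream transformers.

import mindspore as ms
from mindone.transformers import AutoModel

# Parameters come out of __init__ already in fp16, so no extra full-model cast is needed.
model = AutoModel.from_pretrained(
    "bert-base-uncased",  # illustrative checkpoint
    mindspore_dtype=ms.float16,
)

model = model.float()  # rebinds every parameter's data as fp32 on CPU
model = model.half()   # and back to fp16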

mindone/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .env import init_env
 from .logger import set_logger
+from .modeling_patch import patch_nn_default_dtype, unpatch_nn_default_dtype
 from .params import count_params
 from .weight_norm import WeightNorm
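
With this re-export, downstream code can import the patch helpers from the shared package instead of the per-library modeling_patch modules:

from mindone.utils import patch_nn_default_dtype, unpatch_nn_default_dtype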

mindone/diffusers/models/modeling_patch.py renamed to mindone/utils/modeling_patch.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
         setattr(attr, "__init__", _new_init)
 
 
-def restore_nn_default_dtype():
+def unpatch_nn_default_dtype():
     """
     Manually restore the original __init__ of all patched nn / mint.nn Cells.
     """
