62 | 62 | from mindspore import Parameter, Tensor, mint, nn, ops |
63 | 63 | from mindspore.nn import CrossEntropyLoss, Identity |
64 | 64 | from mindspore.nn.utils import no_init_parameters |
65 | | -from mindspore.ops import Cast |
| 65 | + |
| 66 | +from mindone.utils.modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype |
66 | 67 |
67 | 68 | from .activations import get_activation |
68 | 69 | from .generation.utils import GenerationMixin |
81 | 82 | prune_linear_layer, |
82 | 83 | ) |
83 | 84 | from .modeling_attn_mask_utils import dtype_to_min |
84 | | -from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype |
85 | 85 | from .utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder |
86 | 86 | from .utils.import_utils import is_sdpa_available |
87 | 87 |
113 | 113 | ] |
114 | 114 |
115 | 115 | logger = logging.get_logger(__name__) |
116 | | -cpu_cast = Cast().set_device("CPU") |
| 116 | +ms.Parameter._data = ms.Tensor.data |
| 117 | +ms.Parameter.data_ptr = ms.Tensor.data_ptr |
117 | 118 |
118 | 119 | _init_weights = True |
119 | 120 |
@@ -377,7 +378,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar |
377 | 378 | local_state = {v.name: v for k, v in model_to_load.parameters_and_names()} |
378 | 379 | for k, v in state_dict.items(): |
379 | 380 | if k in local_state: |
380 | | - state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k) |
| 381 | + v._data = v.to(device="CPU", dtype=local_state[k].dtype) |
381 | 382 | else: |
382 | 383 | pass # unexpect key keeps origin dtype |
383 | 384 | cm = silence_mindspore_logger() if is_sharded else nullcontext() |
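
Reviewer note: a minimal, hedged sketch of what the updated loading step does, factored as a standalone helper. `align_checkpoint_dtypes` is an illustrative name, and the sketch assumes the checkpoint values are `ms.Parameter` objects so the `Parameter._data` alias added at the top of this file applies.

```python
from mindspore import nn

def align_checkpoint_dtypes(model: nn.Cell, state_dict: dict) -> dict:
    """Cast matched checkpoint entries on CPU to the dtype of the corresponding model parameter."""
    local_state = {p.name: p for p in model.get_parameters()}
    for name, value in state_dict.items():
        if name in local_state:
            # Same in-place swap as the hunk above; relies on the module-level
            # Parameter._data alias and on checkpoint values being Parameters.
            value._data = value.to(device="CPU", dtype=local_state[name].dtype)
        # Unmatched keys keep their original dtype, as in the hunk above.
    return state_dict
```
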
@@ -514,17 +515,17 @@ def _get_name(self): |
514 | 515 | def to(self, dtype: Optional[ms.Type] = None): |
515 | 516 | for p in self.get_parameters(): |
516 | 517 | if p.dtype != dtype: |
517 | | - p.set_dtype(dtype) |
| 518 | + p._data = p.to(device="CPU", dtype=dtype) |
518 | 519 | return self |
519 | 520 |
520 | 521 | def float(self): |
521 | 522 | for p in self.get_parameters(): |
522 | | - p.set_dtype(ms.float32) |
| 523 | + p._data = p.to(device="CPU", dtype=ms.float32) |
523 | 524 | return self |
524 | 525 |
525 | 526 | def half(self): |
526 | 527 | for p in self.get_parameters(): |
527 | | - p.set_dtype(ms.float16) |
| 528 | + p._data = p.to(device="CPU", dtype=ms.float16) |
528 | 529 | return self |
529 | 530 |
530 | 531 | @property |
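
Reviewer note: `to()`, `float()`, and `half()` above now share one pattern, re-pointing each parameter at a CPU-cast copy of itself via the `Parameter._data` alias instead of calling `set_dtype`. A hedged sketch of the common helper this could be factored into (`_cast_params_inplace` is an illustrative name, not part of this PR):

```python
import mindspore as ms
from mindspore import nn

def _cast_params_inplace(cell: nn.Cell, dtype: ms.Type) -> nn.Cell:
    """Cast every parameter of `cell` to `dtype` on CPU, mirroring to()/float()/half() above."""
    for p in cell.get_parameters():
        if p.dtype != dtype:
            # Swap the underlying tensor rather than mutating the live parameter's dtype.
            p._data = p.to(device="CPU", dtype=dtype)
    return cell
```
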
@@ -1162,12 +1163,12 @@ def _from_config(cls, config, **kwargs): |
1162 | 1163 | if "attn_implementation" in kwargs: |
1163 | 1164 | config._attn_implementation = kwargs.pop("attn_implementation") |
1164 | 1165 |
| 1166 | + if mindspore_dtype is not None: |
| 1167 | + patch_nn_default_dtype(dtype=mindspore_dtype, force=True) |
1165 | 1168 | with no_init_parameters(): |
1166 | | - if mindspore_dtype is not None: |
1167 | | - patch_nn_default_dtype(dtype=mindspore_dtype, force=True) |
1168 | 1169 | model = cls(config, **kwargs) |
1169 | | - if mindspore_dtype is not None: |
1170 | | - restore_nn_default_dtype() |
| 1170 | + if mindspore_dtype is not None: |
| 1171 | + restore_nn_default_dtype() |
1171 | 1172 |
1172 | 1173 | # We cannot set default mindspore dtype. So we need to cast model weights after creating. |
1173 | 1174 | if mindspore_dtype is not None: |
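
Reviewer note: this hunk (and the matching one in `from_pretrained` below) hoists the default-dtype patch out of the `no_init_parameters()` block so it brackets the whole model construction. A minimal sketch of the sequence; the `try`/`finally` guard and the helper name are illustrative suggestions to guarantee the restore if construction raises, not something this PR adds:

```python
from mindspore.nn.utils import no_init_parameters

from mindone.utils.modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype

def build_with_default_dtype(cls, config, mindspore_dtype=None, **kwargs):
    """Illustrative helper: construct `cls(config)` with nn layers defaulting to `mindspore_dtype`."""
    if mindspore_dtype is not None:
        patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
    try:
        # Skip real weight initialization; weights are loaded or cast afterwards.
        with no_init_parameters():
            model = cls(config, **kwargs)
    finally:
        if mindspore_dtype is not None:
            restore_nn_default_dtype()
    return model
```
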
@@ -2763,12 +2764,12 @@ def from_pretrained( |
2763 | 2764 |
2764 | 2765 | config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. |
2765 | 2766 |
| 2767 | + if mindspore_dtype is not None: |
| 2768 | + patch_nn_default_dtype(dtype=mindspore_dtype, force=True) |
2766 | 2769 | with no_init_parameters(): |
2767 | | - if mindspore_dtype is not None: |
2768 | | - patch_nn_default_dtype(dtype=mindspore_dtype, force=True) |
2769 | 2770 | model = cls(config, *model_args, **model_kwargs) |
2770 | | - if mindspore_dtype is not None: |
2771 | | - restore_nn_default_dtype() |
| 2771 | + if mindspore_dtype is not None: |
| 2772 | + restore_nn_default_dtype() |
2772 | 2773 |
2773 | 2774 | # Make sure to tie the weights correctly |
2774 | 2775 | model.tie_weights() |