
Commit acb4f23

fix(transformers): support loading weights from an explicit transformers weights file
1 parent: 8f6e9c9

File tree: 1 file changed (+22, -2 lines)

mindone/transformers/modeling_utils.py

Lines changed: 22 additions & 2 deletions
@@ -2454,6 +2454,18 @@ def from_pretrained(
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")
 
+        transformers_explicit_filename = getattr(config, "transformers_weights", None)
+
+        if transformers_explicit_filename is not None:
+            if not transformers_explicit_filename.endswith(
+                ".safetensors"
+            ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )
+
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
         is_sharded = False
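
The new guard only accepts safetensors filenames. As a quick illustration of what passes and what raises, here is a minimal sketch of the same check, with hypothetical filenames that are not taken from the commit:

# Minimal sketch of the validation above; the example filenames are hypothetical.
def is_valid_explicit_filename(name: str) -> bool:
    # Only a safetensors file or a safetensors index file is accepted.
    return name.endswith(".safetensors") or name.endswith(".safetensors.index.json")

assert is_valid_explicit_filename("model.safetensors")              # single-file checkpoint
assert is_valid_explicit_filename("model.safetensors.index.json")   # index of a sharded checkpoint
assert not is_valid_explicit_filename("pytorch_model.bin")          # rejected -> ValueError in from_pretrained
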
@@ -2469,7 +2481,12 @@ def from_pretrained(
             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
             is_local = os.path.isdir(pretrained_model_name_or_path)
             if is_local:
-                if from_tf and os.path.isfile(
+                if transformers_explicit_filename is not None:
+                    # If the filename is explicitly defined, load this by default.
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder,
+                                                transformers_explicit_filename)
+                    is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+                elif from_tf and os.path.isfile(
                     os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
                 ):
                     # Load from a TF 1.0 checkpoint in priority if from_tf
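
For the local-directory branch, the resolved path and shard flag behave as in this small sketch, assuming an illustrative checkout at ./my_model with an empty subfolder (both values are hypothetical):

import os

# Hypothetical inputs for illustration; not taken from the commit.
pretrained_model_name_or_path = "./my_model"
subfolder = ""
transformers_explicit_filename = "model.safetensors.index.json"

archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")

print(archive_file)  # ./my_model/model.safetensors.index.json
print(is_sharded)    # True: the archive file is the shard index, not the weights themselves
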
@@ -2558,7 +2575,10 @@ def from_pretrained(
                 resolved_archive_file = download_url(pretrained_model_name_or_path)
             else:
                 # set correct filename
-                if from_tf:
+                if transformers_explicit_filename is not None:
+                    filename = transformers_explicit_filename
+                    is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+                elif from_tf:
                     filename = TF2_WEIGHTS_NAME
                 elif from_flax:
                     filename = FLAX_WEIGHTS_NAME
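
Taken together, a checkpoint whose config.json declares a transformers_weights entry now steers from_pretrained to that file in both the local and remote code paths. A hedged usage sketch follows; the model id and the config field value are hypothetical, and AutoModel is assumed to be exported by mindone.transformers:

# Hedged usage sketch; "my-org/my-model" is a hypothetical checkpoint whose
# config.json contains e.g. "transformers_weights": "model.safetensors.index.json".
from mindone.transformers import AutoModel  # assumes AutoModel is exported here

# from_pretrained now loads the explicitly named file (treating
# *.safetensors.index.json as a sharded checkpoint) instead of falling
# through to the TF/Flax/default filename resolution.
model = AutoModel.from_pretrained("my-org/my-model")
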
