@@ -2454,6 +2454,18 @@ def from_pretrained(
24542454 if "attn_implementation" in kwargs :
24552455 config ._attn_implementation = kwargs .pop ("attn_implementation" )
24562456
2457+ transformers_explicit_filename = getattr (config , "transformers_weights" , None )
2458+
2459+ if transformers_explicit_filename is not None :
2460+ if not transformers_explicit_filename .endswith (
2461+ ".safetensors"
2462+ ) and not transformers_explicit_filename .endswith (".safetensors.index.json" ):
2463+ raise ValueError (
2464+ "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
2465+ "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
2466+ f"{ transformers_explicit_filename } "
2467+ )
2468+
24572469 # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
24582470 # index of the files.
24592471 is_sharded = False
@@ -2469,7 +2481,12 @@ def from_pretrained(
24692481 pretrained_model_name_or_path = str (pretrained_model_name_or_path )
24702482 is_local = os .path .isdir (pretrained_model_name_or_path )
24712483 if is_local :
2472- if from_tf and os .path .isfile (
2484+ if transformers_explicit_filename is not None :
2485+ # If the filename is explicitly defined, load this by default.
2486+ archive_file = os .path .join (pretrained_model_name_or_path , subfolder ,
2487+ transformers_explicit_filename )
2488+ is_sharded = transformers_explicit_filename .endswith (".safetensors.index.json" )
2489+ elif from_tf and os .path .isfile (
24732490 os .path .join (pretrained_model_name_or_path , subfolder , TF_WEIGHTS_NAME + ".index" )
24742491 ):
24752492 # Load from a TF 1.0 checkpoint in priority if from_tf
@@ -2558,7 +2575,10 @@ def from_pretrained(
25582575 resolved_archive_file = download_url (pretrained_model_name_or_path )
25592576 else :
25602577 # set correct filename
2561- if from_tf :
2578+ if transformers_explicit_filename is not None :
2579+ filename = transformers_explicit_filename
2580+ is_sharded = transformers_explicit_filename .endswith (".safetensors.index.json" )
2581+ elif from_tf :
25622582 filename = TF2_WEIGHTS_NAME
25632583 elif from_flax :
25642584 filename = FLAX_WEIGHTS_NAME
0 commit comments