@@ -254,6 +254,7 @@ def get_cached_module_file(
254254 token : Optional [Union [bool , str ]] = None ,
255255 revision : Optional [str ] = None ,
256256 local_files_only : bool = False ,
257+ local_dir : Optional [str ] = None ,
257258):
258259 """
259260 Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -332,6 +333,7 @@ def get_cached_module_file(
332333 force_download = force_download ,
333334 proxies = proxies ,
334335 local_files_only = local_files_only ,
336+ local_dir = local_dir ,
335337 )
336338 submodule = "git"
337339 module_file = pretrained_model_name_or_path + ".py"
@@ -355,6 +357,7 @@ def get_cached_module_file(
355357 force_download = force_download ,
356358 proxies = proxies ,
357359 local_files_only = local_files_only ,
360+ local_dir = local_dir ,
358361 token = token ,
359362 )
360363 submodule = os .path .join ("local" , "--" .join (pretrained_model_name_or_path .split ("/" )))
@@ -415,6 +418,7 @@ def get_cached_module_file(
415418 token = token ,
416419 revision = revision ,
417420 local_files_only = local_files_only ,
421+ local_dir = local_dir ,
418422 )
419423 return os .path .join (full_submodule , module_file )
420424
@@ -431,7 +435,7 @@ def get_class_from_dynamic_module(
431435 token : Optional [Union [bool , str ]] = None ,
432436 revision : Optional [str ] = None ,
433437 local_files_only : bool = False ,
434- ** kwargs ,
438+ local_dir : Optional [ str ] = None ,
435439):
436440 """
437441 Extracts a class from a module file, present in the local folder or repository of a model.
@@ -496,5 +500,6 @@ def get_class_from_dynamic_module(
496500 token = token ,
497501 revision = revision ,
498502 local_files_only = local_files_only ,
503+ local_dir = local_dir ,
499504 )
500505 return get_class_in_module (class_name , final_module )
0 commit comments