Skip to content

Commit d54622c

Browse files
DN6yiyixuxu
andauthored
[Modular] Allow custom blocks to be saved to local_dir (#12381)
update Co-authored-by: YiYi Xu <[email protected]>
1 parent df8dd77 commit d54622c

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def from_pretrained(
305305
"cache_dir",
306306
"force_download",
307307
"local_files_only",
308+
"local_dir",
308309
"proxies",
309310
"resume_download",
310311
"revision",
@@ -331,7 +332,6 @@ def from_pretrained(
331332
module_file=module_file,
332333
class_name=class_name,
333334
**hub_kwargs,
334-
**kwargs,
335335
)
336336
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
337337
block_kwargs = {

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def get_cached_module_file(
254254
token: Optional[Union[bool, str]] = None,
255255
revision: Optional[str] = None,
256256
local_files_only: bool = False,
257+
local_dir: Optional[str] = None,
257258
):
258259
"""
259260
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -332,6 +333,7 @@ def get_cached_module_file(
332333
force_download=force_download,
333334
proxies=proxies,
334335
local_files_only=local_files_only,
336+
local_dir=local_dir,
335337
)
336338
submodule = "git"
337339
module_file = pretrained_model_name_or_path + ".py"
@@ -355,6 +357,7 @@ def get_cached_module_file(
355357
force_download=force_download,
356358
proxies=proxies,
357359
local_files_only=local_files_only,
360+
local_dir=local_dir,
358361
token=token,
359362
)
360363
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -415,6 +418,7 @@ def get_cached_module_file(
415418
token=token,
416419
revision=revision,
417420
local_files_only=local_files_only,
421+
local_dir=local_dir,
418422
)
419423
return os.path.join(full_submodule, module_file)
420424

@@ -431,7 +435,7 @@ def get_class_from_dynamic_module(
431435
token: Optional[Union[bool, str]] = None,
432436
revision: Optional[str] = None,
433437
local_files_only: bool = False,
434-
**kwargs,
438+
local_dir: Optional[str] = None,
435439
):
436440
"""
437441
Extracts a class from a module file, present in the local folder or repository of a model.
@@ -496,5 +500,6 @@ def get_class_from_dynamic_module(
496500
token=token,
497501
revision=revision,
498502
local_files_only=local_files_only,
503+
local_dir=local_dir,
499504
)
500505
return get_class_in_module(class_name, final_module)

0 commit comments

Comments
 (0)