diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 5d0752af8983..4878937ab202 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -154,12 +154,30 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path): +def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None): """ Import a module on the cache directory for modules and extract a class from it. """ module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as e: + # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory + # separator. We do a bit of monkey patching to detect and fix this case. + if not ( + pretrained_model_name_or_path is not None + and "." in pretrained_model_name_or_path + and module_path.startswith("diffusers_modules") + and pretrained_model_name_or_path.replace("/", "--") in module_path + ): + raise e # We can't figure this one out, just reraise the original error + + corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py" + corrected_path = corrected_path.replace( + pretrained_model_name_or_path.replace("/", "--").replace(".", "/"), + pretrained_model_name_or_path.replace("/", "--"), + ) + module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module() if class_name is None: return find_pipeline_class(module) @@ -454,4 +472,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module.replace(".py", "")) + return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index f1d9d244e546..65718a254595 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1105,6 +1105,21 @@ def test_remote_auto_custom_pipe(self): assert images.shape == (1, 64, 64, 3) + def test_remote_custom_pipe_with_dot_in_name(self): + # make sure that trust remote code has to be passed + with self.assertRaises(ValueError): + pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name") + + pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True) + + assert pipeline.__class__.__name__ == "CustomPipeline" + + pipeline = pipeline.to(torch_device) + images, output_str = pipeline(num_inference_steps=2, output_type="np") + + assert images[0].shape == (1, 32, 32, 3) + assert output_str == "This is a test" + def test_local_custom_pipeline_repo(self): local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") pipeline = DiffusionPipeline.from_pretrained(