diff --git a/src/ragas/prompt/mixin.py b/src/ragas/prompt/mixin.py index c354a8d9e..17db3b682 100644 --- a/src/ragas/prompt/mixin.py +++ b/src/ragas/prompt/mixin.py @@ -20,8 +20,9 @@ class PromptMixin: eg: [BaseSynthesizer][ragas.testset.synthesizers.base.BaseSynthesizer], [MetricWithLLM][ragas.metrics.base.MetricWithLLM] """ - def _get_prompts(self) -> t.Dict[str, PydanticPrompt]: + name: str = "" + def _get_prompts(self) -> t.Dict[str, PydanticPrompt]: prompts = {} for key, value in inspect.getmembers(self): if isinstance(value, PydanticPrompt): @@ -90,10 +91,13 @@ def save_prompts(self, path: str): prompts = self.get_prompts() for prompt_name, prompt in prompts.items(): # hash_hex = f"0x{hash(prompt) & 0xFFFFFFFFFFFFFFFF:016x}" - prompt_file_name = os.path.join( - path, f"{prompt_name}_{prompt.language}.json" - ) - prompt.save(prompt_file_name) + if self.name == "": + file_name = os.path.join(path, f"{prompt_name}_{prompt.language}.json") + else: + file_name = os.path.join( + path, f"{self.name}_{prompt_name}_{prompt.language}.json" + ) + prompt.save(file_name) def load_prompts(self, path: str, language: t.Optional[str] = None): """ @@ -113,7 +117,12 @@ def load_prompts(self, path: str, language: t.Optional[str] = None): loaded_prompts = {} for prompt_name, prompt in self.get_prompts().items(): - prompt_file_name = os.path.join(path, f"{prompt_name}_{language}.json") - loaded_prompt = prompt.__class__.load(prompt_file_name) + if self.name == "": + file_name = os.path.join(path, f"{prompt_name}_{language}.json") + else: + file_name = os.path.join( + path, f"{self.name}_{prompt_name}_{language}.json" + ) + loaded_prompt = prompt.__class__.load(file_name) loaded_prompts[prompt_name] = loaded_prompt return loaded_prompts