From d6b9e754ff75fcd90560900f9f887eacaae52f48 Mon Sep 17 00:00:00 2001
From: Jithin James
Date: Thu, 10 Oct 2024 17:01:27 +0530
Subject: [PATCH] feat: save and load Prompts (#1458)

---
 src/ragas/prompt/base.py                       |  14 ++-
 src/ragas/prompt/mixin.py                      |  46 ++++++++
 src/ragas/prompt/pydantic_prompt.py            |  99 +++++++++++++++++
 .../testset/synthesizers/abstract_query.py     |   4 +-
 src/ragas/testset/synthesizers/prompts.py      |   2 +-
 tests/conftest.py                              |  32 +++---
 tests/unit/prompt/test_prompt_mixin.py         |  48 +++++++++
 tests/unit/test_prompt.py                      | 102 +++++++++++++++++-
 8 files changed, 324 insertions(+), 23 deletions(-)
 create mode 100644 tests/unit/prompt/test_prompt_mixin.py

diff --git a/src/ragas/prompt/base.py b/src/ragas/prompt/base.py
index bc2cad7df..e2f787219 100644
--- a/src/ragas/prompt/base.py
+++ b/src/ragas/prompt/base.py
@@ -25,12 +25,18 @@ def _check_if_language_is_supported(language: str):


 class BasePrompt(ABC):
-    def __init__(self, name: t.Optional[str] = None, language: str = "english"):
+    def __init__(
+        self,
+        name: t.Optional[str] = None,
+        language: str = "english",
+        original_hash: t.Optional[int] = None,
+    ):
         if name is None:
             self.name = camel_to_snake(self.__class__.__name__)

         _check_if_language_is_supported(language)
         self.language = language
+        self.original_hash = original_hash

     @abstractmethod
     async def generate(
@@ -65,10 +71,16 @@ def generate_multiple(
 class StringIO(BaseModel):
     text: str

+    def __hash__(self):
+        return hash(self.text)
+

 class BoolIO(BaseModel):
     value: bool

+    def __hash__(self):
+        return hash(self.value)
+

 class StringPrompt(BasePrompt):
     """
diff --git a/src/ragas/prompt/mixin.py b/src/ragas/prompt/mixin.py
index 14191cf00..3924735c9 100644
--- a/src/ragas/prompt/mixin.py
+++ b/src/ragas/prompt/mixin.py
@@ -1,14 +1,20 @@
 from __future__ import annotations

 import inspect
+import logging
+import os
 import typing as t

+from .base import _check_if_language_is_supported
 from .pydantic_prompt import PydanticPrompt

 if t.TYPE_CHECKING:
     from ragas.llms.base import BaseRagasLLM


+logger = logging.getLogger(__name__)
+
+
 class PromptMixin:
     def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
         prompts = {}
@@ -40,3 +46,43 @@ async def adapt_prompts(
             adapted_prompts[name] = adapted_prompt

         return adapted_prompts
+
+    def save_prompts(self, path: str):
+        """
+        Save prompts to a directory in the format of {name}_{language}.json.
+        """
+        # check if path is valid
+        if not os.path.exists(path):
+            raise ValueError(f"Path {path} does not exist")
+
+        prompts = self.get_prompts()
+        for prompt_name, prompt in prompts.items():
+            prompt_file_name = os.path.join(
+                path, f"{prompt_name}_{prompt.language}.json"
+            )
+            prompt.save(prompt_file_name)
+
+    def load_prompts(self, path: str, language: t.Optional[str] = None):
+        """
+        Load prompts from a directory in the format of {name}_{language}.json.
+        """
+        # check if path is valid
+        if not os.path.exists(path):
+            raise ValueError(f"Path {path} does not exist")
+
+        # check if language is supported; defaults to english
+        if language is None:
+            language = "english"
+            logger.info(
+                "Language not specified, loading prompts for default language: %s",
+                language,
+            )
+        _check_if_language_is_supported(language)
+
+        loaded_prompts = {}
+        for prompt_name, prompt in self.get_prompts().items():
+            prompt_file_name = os.path.join(path, f"{prompt_name}_{language}.json")
+            loaded_prompt = prompt.__class__.load(prompt_file_name)
+            loaded_prompts[prompt_name] = loaded_prompt
+        return loaded_prompts
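
A minimal usage sketch of the two new mixin methods (assumptions: `my_llm` is any configured BaseRagasLLM, and the target directory already exists, since save_prompts raises ValueError for a missing path):

    import os

    from ragas.testset.synthesizers import AbstractQuerySynthesizer

    synth = AbstractQuerySynthesizer(llm=my_llm)  # my_llm: assumed BaseRagasLLM

    os.makedirs("prompts", exist_ok=True)  # save_prompts expects an existing directory
    synth.save_prompts("prompts")  # writes one {name}_{language}.json per prompt

    loaded = synth.load_prompts("prompts")  # language defaults to "english"
    synth.set_prompts(**loaded)
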
diff --git a/src/ragas/prompt/pydantic_prompt.py b/src/ragas/prompt/pydantic_prompt.py
index 91b1a1e78..486542a24 100644
--- a/src/ragas/prompt/pydantic_prompt.py
+++ b/src/ragas/prompt/pydantic_prompt.py
@@ -1,13 +1,16 @@
 from __future__ import annotations

 import copy
+import json
 import logging
+import os
 import typing as t

 from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import PydanticOutputParser
 from pydantic import BaseModel

+from ragas._version import __version__
 from ragas.callbacks import new_group
 from ragas.exceptions import RagasOutputParserException
 from ragas.llms.prompt import PromptValue
@@ -220,6 +223,11 @@ async def adapt(
         # throws ValueError if language is not supported
         _check_if_language_is_supported(target_language)

+        # set the original hash; this is used to
+        # identify the original prompt object when loading from file
+        if self.original_hash is None:
+            self.original_hash = hash(self)
+
         strings = get_all_strings(self.examples)
         translated_strings = await translate_statements_prompt.generate(
             llm=llm,
@@ -237,6 +245,97 @@ async def adapt(
         new_prompt.language = target_language
         return new_prompt

+    def __hash__(self):
+        # convert examples to json strings for hashing
+        examples = []
+        for example in self.examples:
+            input_model, output_model = example
+            examples.append(
+                (input_model.model_dump_json(), output_model.model_dump_json())
+            )
+
+        # input_model and output_model are included so that prompts with the
+        # same text but different IO models hash differently
+        return hash(
+            (
+                self.name,
+                self.input_model,
+                self.output_model,
+                self.instruction,
+                *examples,
+                self.language,
+            )
+        )
+
+    def __eq__(self, other):
+        if not isinstance(other, PydanticPrompt):
+            return False
+        return (
+            self.name == other.name
+            and self.input_model == other.input_model
+            and self.output_model == other.output_model
+            and self.instruction == other.instruction
+            and self.examples == other.examples
+            and self.language == other.language
+        )
+
+    def save(self, file_path: str):
+        """
+        Save the prompt to a file.
+        """
+        data = {
+            "ragas_version": __version__,
+            "original_hash": (
+                hash(self) if self.original_hash is None else self.original_hash
+            ),
+            "language": self.language,
+            "instruction": self.instruction,
+            "examples": [
+                {"input": example[0].model_dump(), "output": example[1].model_dump()}
+                for example in self.examples
+            ],
+        }
+        if os.path.exists(file_path):
+            raise FileExistsError(f"The file '{file_path}' already exists.")
+        with open(file_path, "w") as f:
+            json.dump(data, f, indent=2)
+        print(f"Prompt saved to {file_path}")
+
+    @classmethod
+    def load(cls, file_path: str) -> "PydanticPrompt[InputModel, OutputModel]":
+        with open(file_path, "r") as f:
+            data = json.load(f)
+
+        # warn when the prompt was saved with a different ragas version
+        ragas_version = data.get("ragas_version")
+        if ragas_version != __version__:
+            logger.warning(
+                "Prompt was saved with Ragas v%s, but you are loading it with Ragas v%s. "
" + "There might be incompatibilities.", + ragas_version, + __version__, + ) + original_hash = data.get("original_hash") + + prompt = cls() + instruction = data["instruction"] + examples = [ + ( + prompt.input_model(**example["input"]), + prompt.output_model(**example["output"]), + ) + for example in data["examples"] + ] + + prompt.instruction = instruction + prompt.examples = examples + prompt.language = data.get("language", prompt.language) + + # Optionally, verify the loaded prompt's hash matches the saved hash + if original_hash is not None and hash(prompt) != original_hash: + logger.warning("Loaded prompt hash does not match the saved hash.") + + return prompt + # Ragas Output Parser class OutputStringAndPrompt(BaseModel): diff --git a/src/ragas/testset/synthesizers/abstract_query.py b/src/ragas/testset/synthesizers/abstract_query.py index 2e547f8ce..f2d0d51be 100644 --- a/src/ragas/testset/synthesizers/abstract_query.py +++ b/src/ragas/testset/synthesizers/abstract_query.py @@ -17,7 +17,7 @@ AbstractQueryFromTheme, CAQInput, CommonConceptsFromKeyphrases, - CommonThemeFromSummaries, + CommonThemeFromSummariesPrompt, ComparativeAbstractQuery, Concepts, KeyphrasesAndNumConcepts, @@ -44,7 +44,7 @@ class AbstractQuerySynthesizer(QuerySynthesizer): def __post_init__(self): super().__post_init__() - self.common_theme_prompt = CommonThemeFromSummaries() + self.common_theme_prompt = CommonThemeFromSummariesPrompt() async def _generate_scenarios( self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks diff --git a/src/ragas/testset/synthesizers/prompts.py b/src/ragas/testset/synthesizers/prompts.py index c3ebc2eaa..f8150c007 100644 --- a/src/ragas/testset/synthesizers/prompts.py +++ b/src/ragas/testset/synthesizers/prompts.py @@ -20,7 +20,7 @@ class Themes(BaseModel): themes: t.List[Theme] -class CommonThemeFromSummaries(PydanticPrompt[Summaries, Themes]): +class CommonThemeFromSummariesPrompt(PydanticPrompt[Summaries, Themes]): input_model = Summaries output_model = Themes instruction = "Analyze the following summaries and identify given number of common themes. The themes should be concise, descriptive, and highlight a key aspect shared across the summaries." 
diff --git a/tests/conftest.py b/tests/conftest.py
index 5dcb9194b..6e5136360 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,22 +28,24 @@ def pytest_configure(config):
     )


-class FakeTestLLM(BaseRagasLLM):
-    def llm(self):
-        return self
-
-    def generate_text(
-        self, prompt: PromptValue, n=1, temperature=1e-8, stop=None, callbacks=[]
-    ):
-        generations = [[Generation(text=prompt.prompt_str)] * n]
-        return LLMResult(generations=generations)
-
-    async def agenerate_text(
-        self, prompt: PromptValue, n=1, temperature=1e-8, stop=None, callbacks=[]
-    ):
-        return self.generate_text(prompt, n, temperature, stop, callbacks)
+class EchoLLM(BaseRagasLLM):
+    def generate_text(  # type: ignore
+        self,
+        prompt: PromptValue,
+        *args,
+        **kwargs,
+    ) -> LLMResult:
+        return LLMResult(generations=[[Generation(text=prompt.to_string())]])
+
+    async def agenerate_text(  # type: ignore
+        self,
+        prompt: PromptValue,
+        *args,
+        **kwargs,
+    ) -> LLMResult:
+        return LLMResult(generations=[[Generation(text=prompt.to_string())]])


 @pytest.fixture
 def fake_llm():
-    return FakeTestLLM()
+    return EchoLLM()
diff --git a/tests/unit/prompt/test_prompt_mixin.py b/tests/unit/prompt/test_prompt_mixin.py
new file mode 100644
index 000000000..c990e16a4
--- /dev/null
+++ b/tests/unit/prompt/test_prompt_mixin.py
@@ -0,0 +1,48 @@
+import pytest
+
+from ragas.testset.synthesizers import AbstractQuerySynthesizer
+
+
+def test_prompt_save_load(tmp_path, fake_llm):
+    synth = AbstractQuerySynthesizer(llm=fake_llm)
+    synth_prompts = synth.get_prompts()
+    synth.save_prompts(tmp_path)
+    loaded_prompts = synth.load_prompts(tmp_path)
+    assert len(synth_prompts) == len(loaded_prompts)
+    for name, prompt in synth_prompts.items():
+        assert name in loaded_prompts
+        assert prompt == loaded_prompts[name]
+
+
+@pytest.mark.asyncio
+async def test_prompt_save_adapt_load(tmp_path, fake_llm):
+    synth = AbstractQuerySynthesizer(llm=fake_llm)
+
+    # patch adapt_prompts so no real translation is needed
+    async def adapt_prompts_patched(self, language, llm):
+        for prompt in self.get_prompts().values():
+            prompt.instruction = "test"
+            prompt.language = language
+        return self.get_prompts()
+
+    synth.adapt_prompts = adapt_prompts_patched.__get__(synth)
+
+    # adapt prompts
+    original_prompts = synth.get_prompts()
+    adapted_prompts = await synth.adapt_prompts("spanish", fake_llm)
+    synth.set_prompts(**adapted_prompts)
+
+    # save and load
+    synth.save_prompts(tmp_path)
+    loaded_prompts = synth.load_prompts(tmp_path, language="spanish")
+
+    # check conditions
+    assert len(adapted_prompts) == len(loaded_prompts)
+    for name, adapted_prompt in adapted_prompts.items():
+        assert name in loaded_prompts
+        assert name in original_prompts
+
+        loaded_prompt = loaded_prompts[name]
+        assert adapted_prompt.instruction == loaded_prompt.instruction
+        assert adapted_prompt.language == loaded_prompt.language
+        assert adapted_prompt == loaded_prompt
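
Outside the patched adapt_prompts above, the adapt-then-persist flow would look roughly like this (a sketch that must run inside an async context; my_llm is again an assumed BaseRagasLLM and performs the actual translation):

    import os

    from ragas.testset.synthesizers import AbstractQuerySynthesizer

    synth = AbstractQuerySynthesizer(llm=my_llm)
    adapted = await synth.adapt_prompts("spanish", my_llm)  # translates instruction/examples
    synth.set_prompts(**adapted)

    os.makedirs("prompts", exist_ok=True)
    synth.save_prompts("prompts")  # files are named {prompt_name}_spanish.json
    loaded = synth.load_prompts("prompts", language="spanish")
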
diff --git a/tests/unit/test_prompt.py b/tests/unit/test_prompt.py
index c431262a8..d51203c46 100644
--- a/tests/unit/test_prompt.py
+++ b/tests/unit/test_prompt.py
@@ -1,3 +1,5 @@
+import copy
+
 import pytest
 from langchain_core.outputs import Generation, LLMResult

@@ -100,12 +102,104 @@ class Prompt(PydanticPrompt[StringIO, StringIO]):


 def test_prompt_hash():
-    from ragas.prompt import StringPrompt
+    from ragas.prompt import PydanticPrompt, StringIO

-    class Prompt(StringPrompt):
+    class Prompt(PydanticPrompt[StringIO, StringIO]):
         instruction = "You are a helpful assistant."
+        input_model = StringIO
+        output_model = StringIO

     p = Prompt()
-    assert hash(p) == hash(p)
+    p_copy = Prompt()
+    assert hash(p) == hash(p_copy)
+    assert p == p_copy

     p.instruction = "You are a helpful assistant. And some more"
-    # assert hash(p) != hash(p)
+    assert hash(p) != hash(p_copy)
+    assert p != p_copy
+
+
+def test_prompt_hash_in_ragas(fake_llm):
+    # check with a prompt inside ragas
+    from ragas.testset.synthesizers import AbstractQuerySynthesizer
+
+    synthesizer = AbstractQuerySynthesizer(llm=fake_llm)
+    prompts = synthesizer.get_prompts()
+    for prompt in prompts.values():
+        assert hash(prompt) == hash(prompt)
+        assert prompt == prompt
+
+    # change instruction and check if hash changes
+    for prompt in prompts.values():
+        old_prompt = copy.deepcopy(prompt)
+        prompt.instruction = "You are a helpful assistant."
+        assert hash(prompt) != hash(old_prompt)
+        assert prompt != old_prompt
+
+
+def test_prompt_save_load(tmp_path):
+    from ragas.prompt import PydanticPrompt, StringIO
+
+    class Prompt(PydanticPrompt[StringIO, StringIO]):
+        instruction = "You are a helpful assistant."
+        input_model = StringIO
+        output_model = StringIO
+        examples = [
+            (StringIO(text="hello"), StringIO(text="hello")),
+            (StringIO(text="world"), StringIO(text="world")),
+        ]
+
+    p = Prompt()
+    file_path = tmp_path / "test_prompt.json"
+    p.save(file_path)
+    p1 = Prompt.load(file_path)
+    assert hash(p) == hash(p1)
+    assert p == p1
+
+
+def test_prompt_save_load_language(tmp_path):
+    from ragas.prompt import PydanticPrompt, StringIO
+
+    class Prompt(PydanticPrompt[StringIO, StringIO]):
+        instruction = "You are a helpful assistant."
+        language = "spanish"
+        input_model = StringIO
+        output_model = StringIO
+        examples = [
+            (StringIO(text="hello"), StringIO(text="hello")),
+            (StringIO(text="world"), StringIO(text="world")),
+        ]
+
+    p_spanish = Prompt()
+    file_path = tmp_path / "test_prompt_spanish.json"
+    p_spanish.save(file_path)
+    p_spanish_loaded = Prompt.load(file_path)
+    assert hash(p_spanish) == hash(p_spanish_loaded)
+    assert p_spanish == p_spanish_loaded
+
+
+def test_save_existing_prompt(tmp_path):
+    from ragas.testset.synthesizers.prompts import CommonThemeFromSummariesPrompt
+
+    p = CommonThemeFromSummariesPrompt()
+    file_path = tmp_path / "test_prompt.json"
+    p.save(file_path)
+    p2 = CommonThemeFromSummariesPrompt.load(file_path)
+    assert p == p2
+
+
+def test_prompt_class_attributes():
+    """
+    We use class attributes to store the prompt instruction and examples.
+    Make sure mutating them on one instance does not leak into the class
+    or into other instances.
+    """
+    from ragas.testset.synthesizers.prompts import CommonThemeFromSummariesPrompt
+
+    p = CommonThemeFromSummariesPrompt()
+    p_another_instance = CommonThemeFromSummariesPrompt()
+    assert p.instruction == p_another_instance.instruction
+    assert p.examples == p_another_instance.examples
+
+    p.instruction = "You are a helpful assistant."
+    p.examples = []
+    assert p.instruction != p_another_instance.instruction
+    assert p.examples != p_another_instance.examples
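
For reference, a quick way to inspect the on-disk format that save produces (the file name is illustrative; the key list follows from the data dict built in PydanticPrompt.save above):

    import json

    from ragas.testset.synthesizers.prompts import CommonThemeFromSummariesPrompt

    p = CommonThemeFromSummariesPrompt()
    p.save("common_theme_english.json")

    with open("common_theme_english.json") as f:
        data = json.load(f)

    # top-level keys written by PydanticPrompt.save
    print(sorted(data))  # ['examples', 'instruction', 'language', 'original_hash', 'ragas_version']
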