feat: save and load Prompts (#1458)
jjmachan authored Oct 10, 2024
1 parent 75fa77b commit d6b9e75
Showing 8 changed files with 324 additions and 23 deletions.
14 changes: 13 additions & 1 deletion src/ragas/prompt/base.py
@@ -25,12 +25,18 @@ def _check_if_language_is_supported(language: str):


class BasePrompt(ABC):
-    def __init__(self, name: t.Optional[str] = None, language: str = "english"):
+    def __init__(
+        self,
+        name: t.Optional[str] = None,
+        language: str = "english",
+        original_hash: t.Optional[str] = None,
+    ):
        if name is None:
            self.name = camel_to_snake(self.__class__.__name__)

        _check_if_language_is_supported(language)
        self.language = language
+        self.original_hash = original_hash

    @abstractmethod
    async def generate(
@@ -65,10 +71,16 @@ def generate_multiple(
class StringIO(BaseModel):
    text: str

+    def __hash__(self):
+        return hash(self.text)


class BoolIO(BaseModel):
    value: bool

+    def __hash__(self):
+        return hash(self.value)


class StringPrompt(BasePrompt):
    """
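The value-based __hash__ methods above let the IO models participate in prompt hashing. A tiny illustrative check (not part of the commit); note that Python salts string hashes per process, so these values are stable within one interpreter session but not across runs:

from ragas.prompt.base import StringIO

a = StringIO(text="hello")
b = StringIO(text="hello")

# equal text gives equal hashes within the same interpreter session;
# across runs the value changes because Python salts string hashes
assert hash(a) == hash(b)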
46 changes: 46 additions & 0 deletions src/ragas/prompt/mixin.py
@@ -1,14 +1,20 @@
from __future__ import annotations

import inspect
+import logging
+import os
import typing as t

+from .base import _check_if_language_is_supported
from .pydantic_prompt import PydanticPrompt

if t.TYPE_CHECKING:
    from ragas.llms.base import BaseRagasLLM


+logger = logging.getLogger(__name__)


class PromptMixin:
    def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
        prompts = {}
@@ -40,3 +46,43 @@ async def adapt_prompts(
            adapted_prompts[name] = adapted_prompt

        return adapted_prompts

+    def save_prompts(self, path: str):
+        """
+        Save prompts to a directory in the format {name}_{language}.json.
+        """
+        # check if path is valid
+        if not os.path.exists(path):
+            raise ValueError(f"Path {path} does not exist")
+
+        prompts = self.get_prompts()
+        for prompt_name, prompt in prompts.items():
+            prompt_file_name = os.path.join(
+                path, f"{prompt_name}_{prompt.language}.json"
+            )
+            prompt.save(prompt_file_name)
+
+    def load_prompts(self, path: str, language: t.Optional[str] = None):
+        """
+        Load prompts from a directory in the format {name}_{language}.json.
+        """
+        # check if path is valid
+        if not os.path.exists(path):
+            raise ValueError(f"Path {path} does not exist")
+
+        # check if language is supported; defaults to english
+        if language is None:
+            language = "english"
+            logger.info(
+                "Language not specified, loading prompts for default language: %s",
+                language,
+            )
+        _check_if_language_is_supported(language)
+
+        loaded_prompts = {}
+        for prompt_name, prompt in self.get_prompts().items():
+            prompt_file_name = os.path.join(path, f"{prompt_name}_{language}.json")
+            loaded_prompt = prompt.__class__.load(prompt_file_name)
+            loaded_prompts[prompt_name] = loaded_prompt
+        return loaded_prompts
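Together these give PromptMixin a full save/load round trip. A usage sketch under assumed names (llm stands in for any BaseRagasLLM instance; the prompts directory is illustrative):

import os

from ragas.testset.synthesizers import AbstractQuerySynthesizer

synth = AbstractQuerySynthesizer(llm=llm)  # llm: any BaseRagasLLM instance

# save_prompts expects the target directory to exist already
os.makedirs("prompts", exist_ok=True)
synth.save_prompts("prompts")  # writes one {name}_{language}.json per prompt

# restore later; language defaults to english when omitted
loaded = synth.load_prompts("prompts")
synth.set_prompts(**loaded)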
99 changes: 99 additions & 0 deletions src/ragas/prompt/pydantic_prompt.py
@@ -1,13 +1,16 @@
from __future__ import annotations

import copy
+import json
import logging
+import os
import typing as t

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel

+from ragas._version import __version__
from ragas.callbacks import new_group
from ragas.exceptions import RagasOutputParserException
from ragas.llms.prompt import PromptValue
@@ -220,6 +223,11 @@ async def adapt(
        # throws ValueError if language is not supported
        _check_if_language_is_supported(target_language)

+        # set the original hash; this is used to
+        # identify the original prompt object when loading from file
+        if self.original_hash is None:
+            self.original_hash = hash(self)
+
        strings = get_all_strings(self.examples)
        translated_strings = await translate_statements_prompt.generate(
            llm=llm,
@@ -237,6 +245,97 @@ async def adapt(
        new_prompt.language = target_language
        return new_prompt

+    def __hash__(self):
+        # convert examples to json strings for hashing
+        examples = []
+        for example in self.examples:
+            input_model, output_model = example
+            examples.append(
+                (input_model.model_dump_json(), output_model.model_dump_json())
+            )
+
+        # not sure if input_model and output_model should be included
+        return hash(
+            (
+                self.name,
+                self.input_model,
+                self.output_model,
+                self.instruction,
+                *examples,
+                self.language,
+            )
+        )
+
+    def __eq__(self, other):
+        if not isinstance(other, PydanticPrompt):
+            return False
+        return (
+            self.name == other.name
+            and self.input_model == other.input_model
+            and self.output_model == other.output_model
+            and self.instruction == other.instruction
+            and self.examples == other.examples
+            and self.language == other.language
+        )
+
+    def save(self, file_path: str):
+        """
+        Save the prompt to a file.
+        """
+        data = {
+            "ragas_version": __version__,
+            "original_hash": (
+                hash(self) if self.original_hash is None else self.original_hash
+            ),
+            "language": self.language,
+            "instruction": self.instruction,
+            "examples": [
+                {"input": example[0].model_dump(), "output": example[1].model_dump()}
+                for example in self.examples
+            ],
+        }
+        if os.path.exists(file_path):
+            raise FileExistsError(f"The file '{file_path}' already exists.")
+        with open(file_path, "w") as f:
+            json.dump(data, f, indent=2)
+        print(f"Prompt saved to {file_path}")
+
+    @classmethod
+    def load(cls, file_path: str) -> "PydanticPrompt[InputModel, OutputModel]":
+        with open(file_path, "r") as f:
+            data = json.load(f)
+
+        # warn if the prompt was saved with a different Ragas version
+        ragas_version = data.get("ragas_version")
+        if ragas_version != __version__:
+            logger.warning(
+                "Prompt was saved with Ragas v%s, but you are loading it with Ragas v%s. "
+                "There might be incompatibilities.",
+                ragas_version,
+                __version__,
+            )
+        original_hash = data.get("original_hash")
+
+        prompt = cls()
+        instruction = data["instruction"]
+        examples = [
+            (
+                prompt.input_model(**example["input"]),
+                prompt.output_model(**example["output"]),
+            )
+            for example in data["examples"]
+        ]
+
+        prompt.instruction = instruction
+        prompt.examples = examples
+        prompt.language = data.get("language", prompt.language)
+
+        # verify that the loaded prompt's hash matches the saved hash
+        if original_hash is not None and hash(prompt) != original_hash:
+            logger.warning("Loaded prompt hash does not match the saved hash.")
+
+        return prompt


# Ragas Output Parser
class OutputStringAndPrompt(BaseModel):
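A round-trip sketch for the new save/load pair; GreetPrompt, its models, and the file name are hypothetical stand-ins, not part of this commit:

from pydantic import BaseModel

from ragas.prompt.pydantic_prompt import PydanticPrompt


class GreetInput(BaseModel):
    name: str


class GreetOutput(BaseModel):
    text: str


class GreetPrompt(PydanticPrompt[GreetInput, GreetOutput]):
    input_model = GreetInput
    output_model = GreetOutput
    instruction = "Greet the user by name."
    examples = [(GreetInput(name="Ada"), GreetOutput(text="Hello, Ada!"))]


prompt = GreetPrompt()
prompt.save("greet_english.json")  # raises FileExistsError if the file already exists

restored = GreetPrompt.load("greet_english.json")
assert restored == prompt  # __eq__ compares name, models, instruction, examples, language

One caveat worth noting: the saved original_hash comes from Python's hash(), which salts strings per process, so load may log the hash-mismatch warning when a prompt is loaded in a different interpreter session even though the content is identical.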
4 changes: 2 additions & 2 deletions src/ragas/testset/synthesizers/abstract_query.py
@@ -17,7 +17,7 @@
    AbstractQueryFromTheme,
    CAQInput,
    CommonConceptsFromKeyphrases,
-    CommonThemeFromSummaries,
+    CommonThemeFromSummariesPrompt,
    ComparativeAbstractQuery,
    Concepts,
    KeyphrasesAndNumConcepts,
@@ -44,7 +44,7 @@ class AbstractQuerySynthesizer(QuerySynthesizer):

def __post_init__(self):
super().__post_init__()
self.common_theme_prompt = CommonThemeFromSummaries()
self.common_theme_prompt = CommonThemeFromSummariesPrompt()

async def _generate_scenarios(
self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks
2 changes: 1 addition & 1 deletion src/ragas/testset/synthesizers/prompts.py
@@ -20,7 +20,7 @@ class Themes(BaseModel):
    themes: t.List[Theme]


-class CommonThemeFromSummaries(PydanticPrompt[Summaries, Themes]):
+class CommonThemeFromSummariesPrompt(PydanticPrompt[Summaries, Themes]):
    input_model = Summaries
    output_model = Themes
    instruction = "Analyze the following summaries and identify given number of common themes. The themes should be concise, descriptive, and highlight a key aspect shared across the summaries."
32 changes: 17 additions & 15 deletions tests/conftest.py
@@ -28,22 +28,24 @@ def pytest_configure(config):
    )


-class FakeTestLLM(BaseRagasLLM):
-    def llm(self):
-        return self
-
-    def generate_text(
-        self, prompt: PromptValue, n=1, temperature=1e-8, stop=None, callbacks=[]
-    ):
-        generations = [[Generation(text=prompt.prompt_str)] * n]
-        return LLMResult(generations=generations)
-
-    async def agenerate_text(
-        self, prompt: PromptValue, n=1, temperature=1e-8, stop=None, callbacks=[]
-    ):
-        return self.generate_text(prompt, n, temperature, stop, callbacks)
+class EchoLLM(BaseRagasLLM):
+    def generate_text(  # type: ignore
+        self,
+        prompt: PromptValue,
+        *args,
+        **kwargs,
+    ) -> LLMResult:
+        return LLMResult(generations=[[Generation(text=prompt.to_string())]])
+
+    async def agenerate_text(  # type: ignore
+        self,
+        prompt: PromptValue,
+        *args,
+        **kwargs,
+    ) -> LLMResult:
+        return LLMResult(generations=[[Generation(text=prompt.to_string())]])


@pytest.fixture
def fake_llm():
-    return FakeTestLLM()
+    return EchoLLM()
48 changes: 48 additions & 0 deletions tests/unit/prompt/test_prompt_mixin.py
@@ -0,0 +1,48 @@
+import pytest
+
+from ragas.testset.synthesizers import AbstractQuerySynthesizer
+
+
+def test_prompt_save_load(tmp_path, fake_llm):
+    synth = AbstractQuerySynthesizer(llm=fake_llm)
+    synth_prompts = synth.get_prompts()
+    synth.save_prompts(tmp_path)
+    loaded_prompts = synth.load_prompts(tmp_path)
+    assert len(synth_prompts) == len(loaded_prompts)
+    for name, prompt in synth_prompts.items():
+        assert name in loaded_prompts
+        assert prompt == loaded_prompts[name]
+
+
+@pytest.mark.asyncio
+async def test_prompt_save_adapt_load(tmp_path, fake_llm):
+    synth = AbstractQuerySynthesizer(llm=fake_llm)
+
+    # patch adapt_prompts
+    async def adapt_prompts_patched(self, language, llm):
+        for prompt in self.get_prompts().values():
+            prompt.instruction = "test"
+            prompt.language = language
+        return self.get_prompts()
+
+    synth.adapt_prompts = adapt_prompts_patched.__get__(synth)
+
+    # adapt prompts
+    original_prompts = synth.get_prompts()
+    adapted_prompts = await synth.adapt_prompts("spanish", fake_llm)
+    synth.set_prompts(**adapted_prompts)
+
+    # save and load
+    synth.save_prompts(tmp_path)
+    loaded_prompts = synth.load_prompts(tmp_path, language="spanish")
+
+    # check conditions
+    assert len(adapted_prompts) == len(loaded_prompts)
+    for name, adapted_prompt in adapted_prompts.items():
+        assert name in loaded_prompts
+        assert name in original_prompts
+
+        loaded_prompt = loaded_prompts[name]
+        assert adapted_prompt.instruction == loaded_prompt.instruction
+        assert adapted_prompt.language == loaded_prompt.language
+        assert adapted_prompt == loaded_prompt
