docs: added api docs for testset generation
jjmachan committed Oct 13, 2024
1 parent 2ac2dc1 commit ac92103
Showing 15 changed files with 241 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/references/generate.md
@@ -0,0 +1 @@
::: ragas.testset.synthesizers.generate
1 change: 1 addition & 0 deletions docs/references/synthesizers.md
@@ -0,0 +1 @@
::: ragas.testset.synthesizers
1 change: 1 addition & 0 deletions docs/references/transforms.md
@@ -0,0 +1 @@
::: ragas.testset.transforms
3 changes: 3 additions & 0 deletions mkdocs.yml
@@ -101,6 +101,9 @@ nav:
- Testset Generation:
- Schemas: references/testset_schema.md
- Graph: references/graph.md
- Transforms: references/transforms.md
- Synthesizers: references/synthesizers.md
- Generation: references/generate.md
- Integrations: references/integrations.md
- ❤️ Community: community/index.md

29 changes: 27 additions & 2 deletions src/ragas/prompt/mixin.py
@@ -16,14 +16,30 @@


class PromptMixin:
"""
Mixin class for classes that have prompts.
e.g. [BaseSynthesizer][ragas.testset.synthesizers.base.BaseSynthesizer], [MetricWithLLM][ragas.metrics.base.MetricWithLLM]
"""

def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
"""
Returns a dictionary of prompts for the class.
"""
prompts = {}
for name, value in inspect.getmembers(self):
if isinstance(value, PydanticPrompt):
prompts.update({name: value})
return prompts

def set_prompts(self, **prompts):
"""
Sets the prompts for the class.
Raises
------
ValueError
If the prompt is not an instance of `PydanticPrompt`.
"""
available_prompts = self.get_prompts()
for key, value in prompts.items():
if key not in available_prompts:
@@ -39,6 +55,15 @@ def set_prompts(self, **prompts):
async def adapt_prompts(
self, language: str, llm: BaseRagasLLM
) -> t.Dict[str, PydanticPrompt]:
"""
Adapts the prompts in the class to the given language using the given LLM.
Notes
-----
Make sure you use the best available LLM for adapting the prompts and then save and load the prompts using
[save_prompts][ragas.prompt.mixin.PromptMixin.save_prompts] and [load_prompts][ragas.prompt.mixin.PromptMixin.load_prompts]
methods.
"""
prompts = self.get_prompts()
adapted_prompts = {}
for name, prompt in prompts.items():
@@ -49,7 +74,7 @@ async def adapt_prompts(

def save_prompts(self, path: str):
"""
save prompts to a directory in the format of {name}_{language}.json
Saves the prompts to a directory in the format of {name}_{language}.json
"""
# check if path is valid
if not os.path.exists(path):
@@ -65,7 +90,7 @@ def save_prompts(self, path: str):

def load_prompts(self, path: str, language: t.Optional[str] = None):
"""
Load prompts from a directory in the format of {name}_{language}.json
Loads the prompts from a path. Files should be in the format of {name}_{language}.json
"""
# check if path is valid
if not os.path.exists(path):
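Taken together, the new docstrings describe an inspect → adapt → save → load workflow for prompts. A minimal sketch, assuming `Faithfulness` as the `PromptMixin` subclass, an OpenAI model via Langchain, and Spanish as the target language (none of which appear in this diff):

```python
# Sketch of the PromptMixin workflow documented above. The metric, model name,
# and target language are illustrative assumptions, not part of this commit.
import asyncio
import os

from langchain_openai import ChatOpenAI
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import Faithfulness  # MetricWithLLM subclasses are PromptMixins

metric = Faithfulness()
print(metric.get_prompts())  # {"<prompt name>": PydanticPrompt, ...}

# Adapt prompts with the best available LLM, then persist and reload them.
llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))
adapted = asyncio.run(metric.adapt_prompts("spanish", llm=llm))
metric.set_prompts(**adapted)

os.makedirs("prompts", exist_ok=True)   # save_prompts expects an existing path
metric.save_prompts("prompts")          # writes {name}_{language}.json files
metric.load_prompts("prompts", language="spanish")
```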
17 changes: 16 additions & 1 deletion src/ragas/testset/synthesizers/__init__.py
@@ -7,17 +7,32 @@
ComparativeAbstractQuerySynthesizer,
)
from .base import BaseSynthesizer
from .base_query import QuerySynthesizer
from .specific_query import SpecificQuerySynthesizer

QueryDistribution = t.List[t.Tuple[BaseSynthesizer, float]]


def default_query_distribution(llm: BaseRagasLLM) -> QueryDistribution:
"""
Default query distribution for the test set.
By default, 25% of the queries are generated using `AbstractQuerySynthesizer`,
25% are generated using `ComparativeAbstractQuerySynthesizer`, and 50% are
generated using `SpecificQuerySynthesizer`.
"""
return [
(AbstractQuerySynthesizer(llm=llm), 0.25),
(ComparativeAbstractQuerySynthesizer(llm=llm), 0.25),
(SpecificQuerySynthesizer(llm=llm), 0.5),
]


__all__ = ["AbstractQuerySynthesizer", "default_query_distribution"]
__all__ = [
"BaseSynthesizer",
"QuerySynthesizer",
"AbstractQuerySynthesizer",
"ComparativeAbstractQuerySynthesizer",
"SpecificQuerySynthesizer",
"default_query_distribution",
]
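For orientation, a small sketch of how `default_query_distribution` is meant to be called; the wrapped OpenAI model is an assumption, any `BaseRagasLLM` works:

```python
# Sketch of default_query_distribution; the wrapped OpenAI model is an assumption.
from langchain_openai import ChatOpenAI
from ragas.llms import LangchainLLMWrapper
from ragas.testset.synthesizers import default_query_distribution

llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))

# A QueryDistribution is a list of (synthesizer, weight) pairs; weights sum to 1.
for synthesizer, weight in default_query_distribution(llm):
    print(f"{synthesizer.__class__.__name__}: {weight:.0%}")
```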
22 changes: 22 additions & 0 deletions src/ragas/testset/synthesizers/abstract_query.py
@@ -48,6 +48,16 @@ class AbstractQueryScenario(BaseScenario):

@dataclass
class AbstractQuerySynthesizer(QuerySynthesizer):
"""
Synthesizes abstract queries by generating a theme and a set of summaries from a
cluster of chunks and then generating queries based on them.
Attributes
----------
generate_user_input_prompt : PydanticPrompt
The prompt used for generating the user input.
"""

generate_user_input_prompt: PydanticPrompt = field(
default_factory=AbstractQueryFromTheme
)
@@ -180,6 +190,18 @@ class ComparativeAbstractQueryScenario(BaseScenario):

@dataclass
class ComparativeAbstractQuerySynthesizer(QuerySynthesizer):
"""
Synthesizes comparative abstract queries by generating a common concept along with
a set of keyphrases and summaries and then generating queries based on them.
Attributes
----------
common_concepts_prompt : PydanticPrompt
The prompt used for generating common concepts.
generate_query_prompt : PydanticPrompt
The prompt used for generating the query.
"""

common_concepts_prompt: PydanticPrompt = field(
default_factory=CommonConceptsFromKeyphrases
)
4 changes: 4 additions & 0 deletions src/ragas/testset/synthesizers/base.py
@@ -63,6 +63,10 @@ class BaseScenario(BaseModel):

@dataclass
class BaseSynthesizer(ABC, t.Generic[Scenario], PromptMixin):
"""
Base class for synthesizing scenarios and samples.
"""

name: str = ""
llm: BaseRagasLLM = field(default_factory=llm_factory)

13 changes: 13 additions & 0 deletions src/ragas/testset/synthesizers/base_query.py
@@ -22,6 +22,19 @@

@dataclass
class QuerySynthesizer(BaseSynthesizer[Scenario]):
"""
Synthesizes Question-Answer pairs. Used as a base class for other query synthesizers.
Attributes
----------
critic_query_prompt : PydanticPrompt
The prompt used for criticizing the query.
query_modification_prompt : PydanticPrompt
The prompt used for modifying the query.
generate_reference_prompt : PydanticPrompt
The prompt used for generating the reference.
"""

critic_query_prompt: PydanticPrompt = field(default_factory=CriticUserInput)
query_modification_prompt: PydanticPrompt = field(default_factory=ModifyUserInput)
generate_reference_prompt: PydanticPrompt = field(default_factory=GenerateReference)
19 changes: 18 additions & 1 deletion src/ragas/testset/synthesizers/generate.py
@@ -29,6 +29,17 @@

@dataclass
class TestsetGenerator:
"""
Generates an evaluation dataset based on given scenarios and parameters.
Attributes
----------
llm : BaseRagasLLM
The language model to use for the generation process.
knowledge_graph : KnowledgeGraph, default empty
The knowledge graph to use for the generation process.
"""

llm: BaseRagasLLM
knowledge_graph: KnowledgeGraph = field(default_factory=KnowledgeGraph)

@@ -37,7 +48,10 @@ def from_langchain(
cls,
llm: LangchainLLM,
knowledge_graph: t.Optional[KnowledgeGraph] = None,
):
) -> TestsetGenerator:
"""
Creates a `TestsetGenerator` from a Langchain LLM.
"""
knowledge_graph = knowledge_graph or KnowledgeGraph()
return cls(LangchainLLMWrapper(llm), knowledge_graph)

@@ -52,6 +66,9 @@ def generate_with_langchain_docs(
with_debugging_logs=False,
raise_exceptions: bool = True,
) -> Testset:
"""
Generates an evaluation dataset from the given Langchain documents.
"""
transforms = transforms or default_transforms()

# convert the documents to Ragas nodes
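A rough end-to-end sketch of the documented `TestsetGenerator` entry points; the document contents, model name, and the `testset_size` parameter name are assumptions, since the hunk above only shows the tail of `generate_with_langchain_docs`:

```python
# Sketch of TestsetGenerator usage. The document contents, model name, and the
# `testset_size` parameter name are assumptions; check the generated API
# reference (docs/references/generate.md) for the exact signature.
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from ragas.testset.synthesizers.generate import TestsetGenerator

docs = [
    Document(
        page_content="Ragas builds a knowledge graph from your documents "
                     "and synthesizes evaluation queries from it.",
        metadata={"source": "overview.md"},
    ),
]

generator = TestsetGenerator.from_langchain(ChatOpenAI(model="gpt-4o"))
testset = generator.generate_with_langchain_docs(docs, testset_size=10)
print(testset)
```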
5 changes: 5 additions & 0 deletions src/ragas/testset/synthesizers/specific_query.py
@@ -32,6 +32,11 @@ class SpecificQueryScenario(BaseScenario):

@dataclass
class SpecificQuerySynthesizer(QuerySynthesizer):
"""
Synthesizes specific queries by choosing specific chunks, generating a keyphrase
from them, and then generating queries based on that keyphrase.
"""

generate_query_prompt: PydanticPrompt = field(default_factory=SpecificQuery)

async def _generate_scenarios(
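The synthesizers documented above can also be combined by hand into a custom distribution instead of relying on `default_query_distribution`; a sketch with illustrative weights and an assumed OpenAI model:

```python
# Sketch of a hand-built query distribution; the model and weights are
# illustrative assumptions.
from langchain_openai import ChatOpenAI
from ragas.llms import LangchainLLMWrapper
from ragas.testset.synthesizers import (
    AbstractQuerySynthesizer,
    SpecificQuerySynthesizer,
)

llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))

# Skew generation toward specific, chunk-grounded queries (weights sum to 1).
query_distribution = [
    (AbstractQuerySynthesizer(llm=llm), 0.3),
    (SpecificQuerySynthesizer(llm=llm), 0.7),
]
```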
40 changes: 39 additions & 1 deletion src/ragas/testset/transforms/__init__.py
@@ -1,4 +1,4 @@
from .base import BaseGraphTransformation
from .base import BaseGraphTransformation, Extractor, RelationshipBuilder, Splitter
from .engine import Parallel, Transforms, apply_transforms, rollback_transforms
from .extractors import (
EmbeddingExtractor,
@@ -15,6 +15,28 @@


def default_transforms() -> Transforms:
"""
Creates and returns a default set of transforms for processing a knowledge graph.
This function defines a series of transformation steps to be applied to a
knowledge graph, including extracting summaries, keyphrases, titles,
headlines, and embeddings, as well as building similarity relationships
between nodes.
The transforms are applied in the following order:
1. Parallel extraction of summaries and headlines
2. Embedding of summaries for document nodes
3. Splitting of headlines
4. Parallel extraction of embeddings, keyphrases, and titles
5. Building cosine similarity relationships between nodes
6. Building cosine similarity relationships between summaries
Returns
-------
Transforms
A list of transformation steps to be applied to the knowledge graph.
"""
from ragas.testset.graph import NodeType

# define the transforms
@@ -46,10 +68,26 @@ def default_transforms():


__all__ = [
# base
"BaseGraphTransformation",
"Extractor",
"RelationshipBuilder",
"Splitter",
# Transform Engine
"Parallel",
"Transforms",
"apply_transforms",
"rollback_transforms",
"default_transforms",
# extractors
"EmbeddingExtractor",
"HeadlinesExtractor",
"KeyphrasesExtractor",
"SummaryExtractor",
"TitleExtractor",
# relationship builders
"CosineSimilarityBuilder",
"SummaryCosineSimilarityBuilder",
# splitters
"HeadlineSplitter",
]
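A sketch of running the default pipeline over a knowledge graph; the `Node` and `KnowledgeGraph` construction mirrors `ragas.testset.graph`, which this commit does not touch, so those details are assumptions:

```python
# Sketch of applying default_transforms to a knowledge graph. The Node and
# KnowledgeGraph construction details are assumptions based on
# ragas.testset.graph, which is not part of this diff.
from ragas.testset.graph import KnowledgeGraph, Node, NodeType
from ragas.testset.transforms import apply_transforms, default_transforms

kg = KnowledgeGraph()
kg.nodes.append(
    Node(
        type=NodeType.DOCUMENT,
        properties={"page_content": "Ragas test set generation docs ..."},
    )
)

# Runs the six documented steps in order: summaries/headlines, summary
# embeddings, headline splitting, embeddings/keyphrases/titles, and the two
# cosine-similarity relationship builders.
apply_transforms(kg, default_transforms())
```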
15 changes: 15 additions & 0 deletions src/ragas/testset/transforms/engine.py
@@ -19,6 +19,14 @@


class Parallel:
"""
Collection of transformations to be applied in parallel.
Examples
--------
>>> Parallel(HeadlinesExtractor(), SummaryExtractor())
"""

def __init__(self, *transformations: BaseGraphTransformation):
self.transformations = list(transformations)

@@ -112,5 +120,12 @@ def apply_transforms(


def rollback_transforms(kg: KnowledgeGraph, transforms: Transforms):
"""
Rolls back a list of transformations applied to a knowledge graph.
Note
----
This is not yet implemented. Please open an issue if you need this feature.
"""
# this will allow you to roll back the transformations
raise NotImplementedError
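Building on the `Parallel` example in the docstring, a sketch of mixing parallel and sequential steps in a custom transform list; constructing the extractors with their default LLMs is an assumption:

```python
# Sketch of a custom transform list using Parallel; assumes the extractors can
# be constructed with their default LLMs.
from ragas.testset.graph import KnowledgeGraph
from ragas.testset.transforms import (
    HeadlinesExtractor,
    KeyphrasesExtractor,
    Parallel,
    SummaryExtractor,
    apply_transforms,
)

kg = KnowledgeGraph()  # assume document nodes were added beforehand

custom_transforms = [
    Parallel(HeadlinesExtractor(), SummaryExtractor()),  # run concurrently
    KeyphrasesExtractor(),                               # then run on its own
]
apply_transforms(kg, custom_transforms)
```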
21 changes: 21 additions & 0 deletions src/ragas/testset/transforms/extractors/embeddings.py
@@ -8,11 +8,32 @@

@dataclass
class EmbeddingExtractor(Extractor):
"""
A class for extracting embeddings from nodes in a knowledge graph.
Attributes
----------
property_name : str
The name of the property to store the embedding
embed_property_name : str
The name of the property containing the text to embed
embedding_model : BaseRagasEmbeddings
The embedding model used for generating embeddings
"""

property_name: str = "embedding"
embed_property_name: str = "page_content"
embedding_model: BaseRagasEmbeddings = field(default_factory=embedding_factory)

async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
"""
Extracts the embedding for a given node.
Raises
------
ValueError
If the property to be embedded is not a string.
"""
text = node.get_property(self.embed_property_name)
if not isinstance(text, str):
raise ValueError(
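A small sketch of the extractor in isolation; the node construction (`NodeType.CHUNK`) and the default embedding model obtained via `embedding_factory` are assumptions:

```python
# Sketch of EmbeddingExtractor on a single node; NodeType.CHUNK and the default
# embedding model (via embedding_factory) are assumptions.
import asyncio

from ragas.testset.graph import Node, NodeType
from ragas.testset.transforms import EmbeddingExtractor

extractor = EmbeddingExtractor()  # embeds the "page_content" property by default

node = Node(
    type=NodeType.CHUNK,
    properties={"page_content": "Ragas synthesizes evaluation queries."},
)
name, vector = asyncio.run(extractor.extract(node))
print(name, len(vector))  # -> "embedding", length = embedding model's dimension
```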