3 changes: 2 additions & 1 deletion src/ragas/metrics/__init__.py
@@ -34,7 +34,7 @@
)
from ragas.metrics._datacompy_score import DataCompyScore
from ragas.metrics._domain_specific_rubrics import RubricsScore
from ragas.metrics._factual_correctness import FactualCorrectness
from ragas.metrics._factual_correctness import FactualCorrectness, factual_correctness
from ragas.metrics._faithfulness import Faithfulness, FaithfulnesswithHHEM, faithfulness
from ragas.metrics._goal_accuracy import (
AgentGoalAccuracyWithoutReference,
@@ -110,6 +110,7 @@
"answer_correctness",
"Faithfulness",
"faithfulness",
"factual_correctness",
"FaithfulnesswithHHEM",
"AnswerSimilarity",
"answer_similarity",
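The new lowercase `factual_correctness` export follows the same convention as the other metrics in this module: a ready-to-use default instance alongside the class. A minimal sketch of how a caller would pick it up — the isinstance check is an assumption that the instance is constructed the same way as `answer_correctness = AnswerCorrectness()` later in this PR:

```python
# Minimal sketch -- assumes factual_correctness is a default FactualCorrectness
# instance, mirroring the answer_correctness / answer_similarity pattern in this PR.
from ragas.metrics import FactualCorrectness, factual_correctness

assert isinstance(factual_correctness, FactualCorrectness)
```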
193 changes: 62 additions & 131 deletions src/ragas/metrics/_answer_correctness.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import logging
import typing as t
from dataclasses import dataclass, field
@@ -9,11 +10,7 @@

from ragas.dataset_schema import SingleTurnSample
from ragas.metrics._answer_similarity import AnswerSimilarity
from ragas.metrics._faithfulness import (
StatementGeneratorInput,
StatementGeneratorOutput,
StatementGeneratorPrompt,
)
from ragas.metrics._faithfulness import StatementGeneratorOutput
from ragas.metrics.base import (
MetricOutputType,
MetricType,
@@ -22,15 +19,20 @@
SingleTurnMetric,
)
from ragas.metrics.utils import fbeta_score
from ragas.prompt import PydanticPrompt
from ragas.prompt.metric_prompts import (
CORRECTNESS_CLASSIFIER_PROMPT,
STATEMENT_GENERATOR_PROMPT,
)
from ragas.run_config import RunConfig

if t.TYPE_CHECKING:
from langchain_core.callbacks import Callbacks

logger = logging.getLogger(__name__)


# ============================================================================
# PYDANTIC MODELS (No LangChain dependencies)
# ============================================================================


class QuestionAnswerGroundTruth(BaseModel):
question: str
answer: list[str]
@@ -48,93 +50,7 @@ class ClassificationWithReason(BaseModel):
FN: list[StatementsWithReason]


class CorrectnessClassifier(
PydanticPrompt[QuestionAnswerGroundTruth, ClassificationWithReason]
):
instruction = "Given a ground truth and an answer statements, analyze each statement and classify them in one of the following categories: TP (true positive): statements that are present in answer that are also directly supported by the one or more statements in ground truth, FP (false positive): statements present in the answer but not directly supported by any statement in ground truth, FN (false negative): statements found in the ground truth but not present in answer. Each statement can only belong to one of the categories. Provide a reason for each classification."
input_model = QuestionAnswerGroundTruth
output_model = ClassificationWithReason
examples = [
(
QuestionAnswerGroundTruth(
question="What powers the sun and what is its primary function?",
answer=[
"The sun is powered by nuclear fission, similar to nuclear reactors on Earth.",
"The primary function of the sun is to provide light to the solar system.",
],
ground_truth=[
"The sun is powered by nuclear fusion, where hydrogen atoms fuse to form helium.",
"This fusion process in the sun's core releases a tremendous amount of energy.",
"The energy from the sun provides heat and light, which are essential for life on Earth.",
"The sun's light plays a critical role in Earth's climate system.",
"Sunlight helps to drive the weather and ocean currents.",
],
),
ClassificationWithReason(
TP=[
StatementsWithReason(
statement="The primary function of the sun is to provide light to the solar system.",
reason="This statement is somewhat supported by the ground truth mentioning the sun providing light and its roles, though it focuses more broadly on the sun's energy.",
)
],
FP=[
StatementsWithReason(
statement="The sun is powered by nuclear fission, similar to nuclear reactors on Earth.",
reason="This statement is incorrect and contradicts the ground truth which states that the sun is powered by nuclear fusion.",
)
],
FN=[
StatementsWithReason(
statement="The sun is powered by nuclear fusion, where hydrogen atoms fuse to form helium.",
reason="This accurate description of the sun’s power source is not included in the answer.",
),
StatementsWithReason(
statement="This fusion process in the sun's core releases a tremendous amount of energy.",
reason="This process and its significance are not mentioned in the answer.",
),
StatementsWithReason(
statement="The energy from the sun provides heat and light, which are essential for life on Earth.",
reason="The answer only mentions light, omitting the essential aspects of heat and its necessity for life, which the ground truth covers.",
),
StatementsWithReason(
statement="The sun's light plays a critical role in Earth's climate system.",
reason="This broader impact of the sun’s light on Earth's climate system is not addressed in the answer.",
),
StatementsWithReason(
statement="Sunlight helps to drive the weather and ocean currents.",
reason="The effect of sunlight on weather patterns and ocean currents is omitted in the answer.",
),
],
),
),
(
QuestionAnswerGroundTruth(
question="What is the boiling point of water?",
answer=[
"The boiling point of water is 100 degrees Celsius at sea level"
],
ground_truth=[
"The boiling point of water is 100 degrees Celsius (212 degrees Fahrenheit) at sea level.",
"The boiling point of water can change with altitude.",
],
),
ClassificationWithReason(
TP=[
StatementsWithReason(
statement="The boiling point of water is 100 degrees Celsius at sea level",
reason="This statement is directly supported by the ground truth which specifies the boiling point of water as 100 degrees Celsius at sea level.",
)
],
FP=[],
FN=[
StatementsWithReason(
statement="The boiling point of water can change with altitude.",
reason="This additional information about how the boiling point of water can vary with altitude is not mentioned in the answer.",
)
],
),
),
]
# Prompts imported from centralized location


@dataclass
@@ -145,11 +61,8 @@ class AnswerCorrectness(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):

Attributes
----------
name: string
The name of the metrics
weights:
a list of two weights corresponding to factuality and semantic similarity
Defaults [0.75, 0.25]
List of two weights for factuality and semantic similarity [0.75, 0.25]
answer_similarity:
The AnswerSimilarity object
"""
@@ -161,10 +74,6 @@ class AnswerCorrectness(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
}
)
output_type = MetricOutputType.CONTINUOUS
correctness_prompt: PydanticPrompt = field(default_factory=CorrectnessClassifier)
statement_generator_prompt: PydanticPrompt = field(
default_factory=StatementGeneratorPrompt
)
weights: list[float] = field(default_factory=lambda: [0.75, 0.25])
beta: float = 1.0
answer_similarity: t.Optional[AnswerSimilarity] = None
@@ -200,50 +109,73 @@ def _compute_statement_presence(
return score

async def _create_simplified_statements(
self, question: str, text: str, callbacks: Callbacks
self, question: str, text: str
) -> StatementGeneratorOutput:
"""Generate statements from text using direct LLM call."""
assert self.llm is not None, "llm is not set"

prompt_input = StatementGeneratorInput(question=question, answer=text)
statements = await self.statement_generator_prompt.generate(
llm=self.llm,
data=prompt_input,
callbacks=callbacks,
prompt = STATEMENT_GENERATOR_PROMPT.format(question=question, answer=text)

# Use Instructor LLM interface for direct API calls without LangChain
result = self.llm.generate(
prompt,
response_model=StatementGeneratorOutput, # type: ignore
)

return statements
# Instructor returns structured objects directly - no JSON parsing needed!
return result

async def _classify_statements(
self, question: str, answer: list[str], ground_truth: list[str]
) -> ClassificationWithReason:
"""Classify statements using direct LLM call."""
assert self.llm is not None, "llm must be set to compute score"

answer_json = json.dumps(answer)
ground_truth_json = json.dumps(ground_truth)

prompt = CORRECTNESS_CLASSIFIER_PROMPT.format(
question=question,
answer_json=answer_json,
ground_truth_json=ground_truth_json,
)

# Use Instructor LLM interface for direct API calls without LangChain
result = self.llm.generate(
prompt,
response_model=ClassificationWithReason, # type: ignore
)

# Instructor returns structured objects directly - no JSON parsing needed!
return result

async def _single_turn_ascore(
self, sample: SingleTurnSample, callbacks: Callbacks
self, sample: SingleTurnSample, callbacks=None
) -> float:
"""Score a single turn sample (callbacks parameter kept for compatibility but ignored)."""
row = sample.to_dict()
score = await self._ascore(row, callbacks)
return score
return await self._ascore(row)

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
async def _ascore(self, row: t.Dict, callbacks=None) -> float:
"""
Calculate answer correctness score.
"""
assert self.llm is not None, "LLM must be set"

# extract the statements from the answer and the ground truth
question = row["user_input"]
statements: t.Dict[str, t.List[str]] = {}
for item in ["response", "reference"]:
statements_x = await self._create_simplified_statements(
question, row[item], callbacks
)
statements_x = statements_x.statements
statements[item] = statements_x
statements_x = await self._create_simplified_statements(question, row[item])
statements[item] = statements_x.statements

if not all([val == [] for val in statements.values()]):
ground_truth = [statement for statement in statements["reference"]]
answer = [statement for statement in statements["response"]]
answers = await self.correctness_prompt.generate(
llm=self.llm,
data=QuestionAnswerGroundTruth(
question=question,
answer=answer,
ground_truth=ground_truth,
),
callbacks=callbacks,
answers = await self._classify_statements(
question=question,
answer=answer,
ground_truth=ground_truth,
)
if answers is None:
return np.nan
@@ -257,9 +189,7 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
else:
assert self.answer_similarity is not None, "AnswerSimilarity must be set"

similarity_score = await self.answer_similarity.ascore(
row, callbacks=callbacks
)
similarity_score = await self.answer_similarity._ascore(row)

score = np.average(
[f1_score, similarity_score],
@@ -269,4 +199,5 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return float(score)


# Create default instance
answer_correctness = AnswerCorrectness()
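For reference, the `self.llm.generate(prompt, response_model=...)` calls above follow the Instructor-style pattern of passing a prompt plus a Pydantic model and getting a parsed instance back. The sketch below shows that pattern against a raw instructor-patched OpenAI client; the wiring is an assumption about what the ragas LLM wrapper delegates to, and `DemoClassification` and the model name are illustrative, not part of this PR:

```python
# Hedged sketch of the prompt-in / Pydantic-model-out pattern used by
# _classify_statements and _create_simplified_statements above. The
# instructor + OpenAI wiring is an assumption; ragas' own wrapper may differ.
import instructor
from openai import OpenAI
from pydantic import BaseModel


class DemoClassification(BaseModel):
    # Simplified stand-in for ClassificationWithReason
    TP: list[str]
    FP: list[str]
    FN: list[str]


client = instructor.from_openai(OpenAI())

result = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    response_model=DemoClassification,  # instructor parses/validates into this model
    messages=[{"role": "user", "content": "Classify these statements ..."}],
)

# result is already a DemoClassification instance -- no manual JSON parsing.
print(result.TP, result.FP, result.FN)
```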
14 changes: 5 additions & 9 deletions src/ragas/metrics/_answer_similarity.py
@@ -15,10 +15,6 @@
SingleTurnMetric,
)

if t.TYPE_CHECKING:
from langchain_core.callbacks.base import Callbacks


logger = logging.getLogger(__name__)


@@ -59,12 +55,12 @@ def __post_init__(self):
}

async def _single_turn_ascore(
self, sample: SingleTurnSample, callbacks: Callbacks
self, sample: SingleTurnSample, callbacks=None
) -> float:
row = sample.to_dict()
return await self._ascore(row, callbacks)
return await self._ascore(row)

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
async def _ascore(self, row: t.Dict, callbacks=None) -> float:
assert self.embeddings is not None, (
f"Error: '{self.name}' requires embeddings to be set."
)
@@ -109,8 +105,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
class AnswerSimilarity(SemanticSimilarity):
name: str = "answer_similarity"

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await super()._ascore(row, callbacks)
async def _ascore(self, row: t.Dict, callbacks=None) -> float:
return await super()._ascore(row)


answer_similarity = AnswerSimilarity()
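With `callbacks` now optional, callers can score a sample without constructing a LangChain `Callbacks` object. A minimal usage sketch — it assumes an embeddings backend has already been configured on the metric and that the public `single_turn_ascore` entry point is used; the sample values are illustrative:

```python
# Minimal usage sketch -- assumes answer_similarity.embeddings has been set up
# beforehand; sample values are illustrative only.
import asyncio

from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import answer_similarity

sample = SingleTurnSample(
    user_input="When was Albert Einstein born?",
    response="Albert Einstein was born in 1879.",
    reference="Einstein was born on 14 March 1879.",
)

# No LangChain Callbacks object is needed any more.
score = asyncio.run(answer_similarity.single_turn_ascore(sample))
print(score)
```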