Commit 1b8376a

Migrate answer_correctness (#2365)
1 parent 07e530b commit 1b8376a

7 files changed: +866 additions, -8 deletions

src/ragas/llms/base.py

Lines changed: 38 additions & 6 deletions

@@ -144,7 +144,7 @@ class LangchainLLMWrapper(BaseRagasLLM):
 
     def __init__(
         self,
-        langchain_llm: BaseLanguageModel[BaseMessage],
+        langchain_llm: BaseLanguageModel,
         run_config: t.Optional[RunConfig] = None,
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
         cache: t.Optional[CacheInterface] = None,
@@ -491,6 +491,13 @@ def llm_factory(
 # Experimental LLM classes migrated from ragas.experimental.llms
 
 
+class InstructorModelArgs(BaseModel):
+    """Simple model arguments configuration for instructor LLMs"""
+
+    temperature: float = 0.01
+    top_p: float = 0.1
+
+
 class InstructorBaseRagasLLM(ABC):
     """Base class for LLMs using the Instructor library pattern."""
 
@@ -505,19 +512,35 @@ def generate(
 
     @abstractmethod
     async def agenerate(
-        self, prompt: str, response_model: t.Type[InstructorTypeVar]
+        self,
+        prompt: str,
+        response_model: t.Type[InstructorTypeVar],
     ) -> InstructorTypeVar:
         """Asynchronously generate a response using the configured LLM."""
 
 
 class InstructorLLM(InstructorBaseRagasLLM):
     """LLM wrapper using the Instructor library for structured outputs."""
 
-    def __init__(self, client: t.Any, model: str, provider: str, **model_args):
+    def __init__(
+        self,
+        client: t.Any,
+        model: str,
+        provider: str,
+        model_args: t.Optional[InstructorModelArgs] = None,
+        **kwargs,
+    ):
         self.client = client
         self.model = model
         self.provider = provider
-        self.model_args = model_args or {}
+
+        # Use deterministic defaults if no model_args provided
+        if model_args is None:
+            model_args = InstructorModelArgs()
+
+        # Convert to dict and merge with any additional kwargs
+        self.model_args = {**model_args.model_dump(), **kwargs}
+
         # Check if client is async-capable at initialization
        self.is_async = self._check_client_async()
 
@@ -624,7 +647,9 @@ def generate(
         return result
 
     async def agenerate(
-        self, prompt: str, response_model: t.Type[InstructorTypeVar]
+        self,
+        prompt: str,
+        response_model: t.Type[InstructorTypeVar],
     ) -> InstructorTypeVar:
         """Asynchronously generate a response using the configured LLM."""
         messages = [{"role": "user", "content": prompt}]
@@ -789,6 +814,13 @@ def _initialize_client(provider: str, client: t.Any) -> t.Any:
             )
         )
 
+    # Create model args with deterministic defaults, allowing override via kwargs
+    model_args = InstructorModelArgs()
+
     return InstructorLLM(
-        client=instructor_patched_client, model=model, provider=provider, **kwargs
+        client=instructor_patched_client,
+        model=model,
+        provider=provider,
+        model_args=model_args,
+        **kwargs,
     )
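
The effect of the new `InstructorModelArgs` block is that every instructor-wrapped LLM now starts from deterministic sampling settings (temperature 0.01, top_p 0.1) unless the caller overrides them. Below is a minimal sketch of how the merge in `InstructorLLM.__init__` behaves; `show_model_args` is a hypothetical helper that only mirrors the `{**model_args.model_dump(), **kwargs}` expression from the diff (pydantic v2 assumed):

from typing import Optional

from pydantic import BaseModel


class InstructorModelArgs(BaseModel):
    """Deterministic sampling defaults introduced in this commit."""

    temperature: float = 0.01
    top_p: float = 0.1


def show_model_args(model_args: Optional[InstructorModelArgs] = None, **kwargs) -> dict:
    # Hypothetical helper: replicates the merge performed in InstructorLLM.__init__.
    if model_args is None:
        model_args = InstructorModelArgs()  # fall back to deterministic defaults
    return {**model_args.model_dump(), **kwargs}  # explicit kwargs win on collisions


print(show_model_args())                 # {'temperature': 0.01, 'top_p': 0.1}
print(show_model_args(temperature=0.7))  # {'temperature': 0.7, 'top_p': 0.1}

Because `instructor_llm_factory` now always passes a fresh `InstructorModelArgs()` into `InstructorLLM`, factory-built LLMs pick up these defaults too, while keyword arguments forwarded through the factory can still override them.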

src/ragas/metrics/collections/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 """Collections of metrics using modern component architecture."""
 
+from ragas.metrics.collections._answer_correctness import AnswerCorrectness
 from ragas.metrics.collections._answer_relevancy import AnswerRelevancy
 from ragas.metrics.collections._answer_similarity import AnswerSimilarity
 from ragas.metrics.collections._bleu_score import BleuScore
@@ -15,6 +16,7 @@
 
 __all__ = [
     "BaseMetric",  # Base class
+    "AnswerCorrectness",
     "AnswerRelevancy",
     "AnswerSimilarity",
     "BleuScore",

Lines changed: 255 additions & 0 deletions

@@ -0,0 +1,255 @@
+"""Answer Correctness metric v2 - Modern implementation with function-based prompts."""
+
+import typing as t
+from typing import List
+
+import numpy as np
+from pydantic import BaseModel
+
+from ragas.metrics.collections.base import BaseMetric
+from ragas.metrics.result import MetricResult
+from ragas.prompt.metrics.answer_correctness import (
+    correctness_classifier_prompt,
+    statement_generator_prompt,
+)
+
+if t.TYPE_CHECKING:
+    from ragas.embeddings.base import BaseRagasEmbedding
+    from ragas.llms.base import InstructorBaseRagasLLM
+
+
+class StatementGeneratorOutput(BaseModel):
+    """Structured output for statement generation."""
+
+    statements: List[str]
+
+
+class StatementsWithReason(BaseModel):
+    """Individual statement with reasoning for classification."""
+
+    statement: str
+    reason: str
+
+
+class ClassificationWithReason(BaseModel):
+    """Structured output for TP/FP/FN classification."""
+
+    TP: List[StatementsWithReason]
+    FP: List[StatementsWithReason]
+    FN: List[StatementsWithReason]
+
+
+class AnswerCorrectness(BaseMetric):
+    """
+    Modern v2 implementation of answer correctness evaluation.
+
+    Measures answer correctness as a weighted combination of:
+    - Factuality: F1 score from statement-level TP/FP/FN classification
+    - Similarity: Semantic similarity between answer and reference
+
+    This implementation uses modern instructor LLMs with structured output and modern embeddings.
+    Only supports modern components - legacy wrappers are rejected with clear error messages.
+
+    Usage:
+        >>> import instructor
+        >>> from openai import AsyncOpenAI
+        >>> from ragas.llms.base import instructor_llm_factory
+        >>> from ragas.embeddings.base import embedding_factory
+        >>> from ragas.metrics.collections import AnswerCorrectness
+        >>>
+        >>> # Setup dependencies
+        >>> client = AsyncOpenAI()
+        >>> llm = instructor_llm_factory("openai", client=client, model="gpt-4o-mini")
+        >>> embeddings = embedding_factory("openai", model="text-embedding-ada-002", client=client, interface="modern")
+        >>>
+        >>> # Create metric instance
+        >>> metric = AnswerCorrectness(llm=llm, embeddings=embeddings)
+        >>>
+        >>> # Single evaluation
+        >>> result = await metric.ascore(
+        ...     user_input="What is the capital of France?",
+        ...     response="Paris is the capital of France and has many museums.",
+        ...     reference="Paris is the capital of France."
+        ... )
+        >>> print(f"Correctness Score: {result.value}")
+        >>>
+        >>> # Custom weights (more factuality focus)
+        >>> factual_metric = AnswerCorrectness(
+        ...     llm=llm,
+        ...     embeddings=embeddings,
+        ...     weights=[0.9, 0.1]
+        ... )
+
+    Attributes:
+        llm: Modern instructor-based LLM for statement generation and classification
+        embeddings: Modern embeddings model for similarity calculation
+        name: The metric name
+        weights: [factuality_weight, similarity_weight] - must sum to > 0
+        beta: F-beta score parameter (β>1 favors recall, β<1 favors precision)
+        allowed_values: Score range (0.0 to 1.0)
+    """
+
+    # Type hints for linter (attributes are set in __init__)
+    llm: "InstructorBaseRagasLLM"
+    embeddings: "BaseRagasEmbedding"
+
+    def __init__(
+        self,
+        llm: "InstructorBaseRagasLLM",
+        embeddings: "BaseRagasEmbedding",
+        name: str = "answer_correctness",
+        weights: List[float] = [0.75, 0.25],
+        beta: float = 1.0,
+        **kwargs,
+    ):
+        """
+        Initialize AnswerCorrectness metric with required components.
+
+        Args:
+            llm: Modern instructor-based LLM for statement generation and classification
+            embeddings: Modern embeddings model for similarity calculation
+            weights: [factuality_weight, similarity_weight]. Must sum to > 0.
+            beta: F-beta score parameter. β>1 favors recall, β<1 favors precision.
+        """
+        # Set attributes explicitly before calling super()
+        self.llm = llm
+        self.embeddings = embeddings
+        self.weights = weights
+        self.beta = beta
+
+        # Validate weights
+        if len(weights) != 2:
+            raise ValueError(
+                "Expects a list of two weights. First for factuality, second for semantic similarity"
+            )
+        if all([w == 0 for w in weights]):
+            raise ValueError("At least one weight must be non-zero")
+        if not all([w >= 0 for w in weights]):
+            raise ValueError("Weights must be non-negative")
+
+        # Validate beta
+        if not isinstance(beta, float):
+            raise ValueError(
+                "Beta must be a float. A beta > 1 gives more weight to recall, while beta < 1 favors precision."
+            )
+
+        # Call super() for validation (without passing llm/embeddings in kwargs)
+        super().__init__(name=name, **kwargs)
+
+    async def ascore(
+        self, user_input: str, response: str, reference: str
+    ) -> MetricResult:
+        """
+        Calculate answer correctness score.
+
+        Components are guaranteed to be validated and non-None by the base class.
+
+        Args:
+            user_input: The original question
+            response: The answer to evaluate
+            reference: The ground truth reference
+
+        Returns:
+            MetricResult with correctness score (0.0-1.0)
+        """
+        # Step 1: Generate statements from both response and reference
+        response_statements = await self._generate_statements(user_input, response)
+        reference_statements = await self._generate_statements(user_input, reference)
+
+        # Step 2: Calculate factuality score via TP/FP/FN classification
+        if response_statements and reference_statements:
+            classification = await self._classify_statements(
+                user_input, response_statements, reference_statements
+            )
+            factuality_score = self._compute_f1_score(classification)
+        else:
+            # If no statements generated, assume perfect match
+            factuality_score = 1.0
+
+        # Step 3: Calculate semantic similarity score
+        if self.weights[1] == 0:
+            similarity_score = 0.0
+        else:
+            similarity_score = await self._calculate_similarity(response, reference)
+
+        # Step 4: Combine scores with weighted average
+        final_score = np.average(
+            [factuality_score, similarity_score],
+            weights=self.weights,
+        )
+
+        return MetricResult(value=float(final_score))
+
+    async def _generate_statements(self, question: str, text: str) -> List[str]:
+        """Generate atomic statements from text using the statement generator prompt."""
+        prompt = statement_generator_prompt(question, text)
+        # Use deterministic defaults set in LLM constructor
+        result = await self.llm.agenerate(prompt, StatementGeneratorOutput)
+        return result.statements
+
+    async def _classify_statements(
+        self,
+        question: str,
+        answer_statements: List[str],
+        ground_truth_statements: List[str],
+    ) -> ClassificationWithReason:
+        """Classify statements as TP/FP/FN using the correctness classifier prompt with strict behavior."""
+        prompt = correctness_classifier_prompt(
+            question, answer_statements, ground_truth_statements
+        )
+        # Use deterministic defaults set in LLM constructor
+        classification = await self.llm.agenerate(prompt, ClassificationWithReason)
+        return classification
+
+    def _compute_f1_score(self, classification: ClassificationWithReason) -> float:
+        """Compute F1 score from TP/FP/FN classification."""
+        tp = len(classification.TP)
+        fp = len(classification.FP)
+        fn = len(classification.FN)
+
+        # Calculate precision and recall
+        if tp + fp == 0:
+            precision = 1.0 if fn == 0 else 0.0
+        else:
+            precision = tp / (tp + fp)
+
+        if tp + fn == 0:
+            recall = 1.0 if fp == 0 else 0.0
+        else:
+            recall = tp / (tp + fn)
+
+        # Calculate F-beta score
+        if precision + recall == 0:
+            return 0.0
+
+        beta_squared = self.beta**2
+        f_score = (
+            (1 + beta_squared)
+            * (precision * recall)
+            / (beta_squared * precision + recall)
+        )
+
+        return float(f_score)
+
+    async def _calculate_similarity(self, response: str, reference: str) -> float:
+        """Calculate semantic similarity between response and reference using embeddings."""
+        # Get embeddings for both texts
+        response_embedding = np.asarray(
+            await self.embeddings.aembed_text(response)
+        ).reshape(1, -1)
+        reference_embedding = np.asarray(
+            await self.embeddings.aembed_text(reference)
+        ).reshape(1, -1)
+
+        # Calculate cosine similarity
+        norm_response = np.linalg.norm(response_embedding, axis=1)
+        norm_reference = np.linalg.norm(reference_embedding, axis=1)
+
+        if norm_response == 0 or norm_reference == 0:
+            return 0.0
+
+        cosine_similarity = np.dot(response_embedding, reference_embedding.T)[0, 0] / (
+            norm_response[0] * norm_reference[0]
+        )
+
+        return float(cosine_similarity)
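
The score returned by `ascore` is a weighted average of an F-beta factuality score and a cosine similarity. A standalone sketch of that arithmetic with hypothetical counts (3 TP, 1 FP, 1 FN), the default weights [0.75, 0.25], beta 1.0, and a stand-in similarity of 0.9; `f_beta` only mirrors the logic of `_compute_f1_score` above:

import numpy as np


def f_beta(tp: int, fp: int, fn: int, beta: float = 1.0) -> float:
    # Precision/recall from statement counts, with the same edge cases as _compute_f1_score.
    if tp + fp == 0:
        precision = 1.0 if fn == 0 else 0.0
    else:
        precision = tp / (tp + fp)
    if tp + fn == 0:
        recall = 1.0 if fp == 0 else 0.0
    else:
        recall = tp / (tp + fn)
    if precision + recall == 0:
        return 0.0
    beta_squared = beta**2
    return (1 + beta_squared) * precision * recall / (beta_squared * precision + recall)


# Hypothetical classification: 3 response statements supported by the reference (TP),
# 1 unsupported response statement (FP), 1 reference statement missing from the response (FN).
factuality = f_beta(tp=3, fp=1, fn=1)  # precision = recall = 0.75 -> F1 = 0.75
similarity = 0.9                       # stand-in for the embedding cosine similarity
final = float(np.average([factuality, similarity], weights=[0.75, 0.25]))
print(factuality, final)               # 0.75 0.7875

With the default weights the factuality term dominates; setting weights=[0.9, 0.1], as in the docstring example, pulls the final score even closer to the pure F1 value.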

Lines changed: 9 additions & 1 deletion

@@ -1,5 +1,13 @@
 """Metric-specific prompts for Ragas evaluation metrics."""
 
+from ragas.prompt.metrics.answer_correctness import (
+    correctness_classifier_prompt,
+    statement_generator_prompt,
+)
 from ragas.prompt.metrics.answer_relevance import answer_relevancy_prompt
 
-__all__ = ["answer_relevancy_prompt"]
+__all__ = [
+    "answer_relevancy_prompt",
+    "correctness_classifier_prompt",
+    "statement_generator_prompt",
+]
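
The re-exports above make the two function-based prompts importable straight from `ragas.prompt.metrics`. As they are called in the metric code above, both are plain functions that take the question plus a text (or statement lists) and return a prompt string, which is then handed to the instructor LLM together with a response model. A sketch of that flow; the question and statements are illustrative values:

from ragas.prompt.metrics import (
    correctness_classifier_prompt,
    statement_generator_prompt,
)

question = "What is the capital of France?"
answer = "Paris is the capital of France and has many museums."

# Returns the statement-generation prompt string used by _generate_statements.
gen_prompt = statement_generator_prompt(question, answer)

# Returns the TP/FP/FN classification prompt string used by _classify_statements.
cls_prompt = correctness_classifier_prompt(
    question,
    ["Paris is the capital of France.", "Paris has many museums."],  # answer statements
    ["Paris is the capital of France."],                              # ground-truth statements
)

print(gen_prompt[:80])
print(cls_prompt[:80])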
