feat: add traces to EvaluationResult #1531

Merged: 9 commits, Oct 18, 2024
158 changes: 133 additions & 25 deletions src/ragas/callbacks.py
@@ -1,24 +1,38 @@
from __future__ import annotations

import json
import typing as t
import uuid
from dataclasses import dataclass, field
from enum import Enum

from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForChainGroup,
AsyncCallbackManagerForChainRun,
BaseCallbackHandler,
CallbackManager,
CallbackManagerForChainGroup,
CallbackManagerForChainRun,
Callbacks,
)
from pydantic import BaseModel, Field


def new_group(
name: str, inputs: t.Dict, callbacks: Callbacks
name: str,
inputs: t.Dict,
callbacks: Callbacks,
tags: t.Optional[t.List[str]] = None,
metadata: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Tuple[CallbackManagerForChainRun, CallbackManagerForChainGroup]:
tags = tags or []
metadata = metadata or {}

# start evaluation chain
if isinstance(callbacks, list):
cm = CallbackManager.configure(inheritable_callbacks=callbacks)
else:
cm = t.cast(CallbackManager, callbacks)
cm.tags = tags
cm.metadata = metadata
rm = cm.on_chain_start({"name": name}, inputs)
child_cm = rm.get_child()
group_cm = CallbackManagerForChainGroup(
@@ -35,24 +49,118 @@ def new_group
return rm, group_cm


async def new_async_group(
name: str, inputs: t.Dict, callbacks: Callbacks
) -> t.Tuple[AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainGroup]:
# start evaluation chain
if isinstance(callbacks, list):
cm = AsyncCallbackManager.configure(inheritable_callbacks=callbacks)
else:
cm = t.cast(AsyncCallbackManager, callbacks)
rm = await cm.on_chain_start({"name": name}, inputs)
child_cm = rm.get_child()
group_cm = AsyncCallbackManagerForChainGroup(
child_cm.handlers,
child_cm.inheritable_handlers,
child_cm.parent_run_id,
parent_run_manager=rm,
tags=child_cm.tags,
inheritable_tags=child_cm.inheritable_tags,
metadata=child_cm.metadata,
inheritable_metadata=child_cm.inheritable_metadata,
)
return rm, group_cm
class ChainType(Enum):
EVALUATION = "evaluation"
METRIC = "metric"
ROW = "row"
RAGAS_PROMPT = "ragas_prompt"


class ChainRun(BaseModel):
run_id: uuid.UUID
parent_run_id: t.Optional[uuid.UUID]
name: str
inputs: t.Dict[str, t.Any]
metadata: t.Dict[str, t.Any]
outputs: t.Dict[str, t.Any] = Field(default_factory=dict)
children: t.List[uuid.UUID] = Field(default_factory=list)


class ChainRunEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, uuid.UUID):
return str(o)
if isinstance(o, ChainType):
return o.value
return json.JSONEncoder.default(self, o)


@dataclass
class RagasTracer(BaseCallbackHandler):
traces: t.Dict[uuid.UUID, ChainRun] = field(default_factory=dict)

def on_chain_start(
self,
serialized: t.Dict[str, t.Any],
inputs: t.Dict[str, t.Any],
*,
run_id: uuid.UUID,
parent_run_id: t.Optional[uuid.UUID] = None,
tags: t.Optional[t.List[str]] = None,
metadata: t.Optional[t.Dict[str, t.Any]] = None,
**kwargs: t.Any,
) -> t.Any:
self.traces[run_id] = ChainRun(
run_id=run_id,
parent_run_id=parent_run_id,
name=serialized["name"],
inputs=inputs,
metadata=metadata or {},
children=[],
)

if parent_run_id and parent_run_id in self.traces:
self.traces[parent_run_id].children.append(run_id)

def on_chain_end(
self,
outputs: t.Dict[str, t.Any],
*,
run_id: uuid.UUID,
**kwargs: t.Any,
) -> t.Any:
self.traces[run_id].outputs = outputs

def to_jsons(self) -> str:
return json.dumps(
[run.model_dump() for run in self.traces.values()],
indent=4,
cls=ChainRunEncoder,
)


@dataclass
class MetricTrace(dict):
scores: t.Dict[str, float] = field(default_factory=dict)

def __repr__(self):
return self.scores.__repr__()

def __str__(self):
return self.__repr__()


def parse_run_traces(
traces: t.Dict[uuid.UUID, ChainRun],
) -> t.List[t.Dict[str, t.Any]]:
root_traces = [
chain_trace
for chain_trace in traces.values()
if chain_trace.parent_run_id is None
]
if len(root_traces) > 1:
raise ValueError(
"Multiple root traces found! This is a bug on our end, please file an issue and we will fix it ASAP :)"
)
root_trace = root_traces[0]

# get all the row traces
parsed_traces = []
for row_uuid in root_trace.children:
row_trace = traces[row_uuid]
metric_traces = MetricTrace()
for metric_uuid in row_trace.children:
metric_trace = traces[metric_uuid]
metric_traces.scores[metric_trace.name] = metric_trace.outputs["output"]
# get all the prompt IO from the metric trace
prompt_traces = {}
for i, prompt_uuid in enumerate(metric_trace.children):
prompt_trace = traces[prompt_uuid]
prompt_traces[f"{i}_{prompt_trace.name}"] = {
"input": prompt_trace.inputs["data"],
"output": prompt_trace.outputs["output"],
}
metric_traces[f"{metric_trace.name}"] = prompt_traces
parsed_traces.append(metric_traces)

return parsed_traces
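
The tracing machinery added above has three parts: `RagasTracer` is a LangChain callback handler that records every chain start/end as a `ChainRun` keyed by run id, `new_group` now accepts `tags` and `metadata` so callers can label chains with a `ChainType`, and `parse_run_traces` flattens the recorded run tree (evaluation -> row -> metric -> prompt) into one dict per row. A minimal sketch of how these pieces compose when wired by hand, outside `evaluate()`; the chain names and sample inputs below are illustrative, not taken from this PR:

```python
# Minimal sketch (not part of this diff): wiring RagasTracer through new_group
# by hand. Chain names and sample inputs are illustrative only.
from ragas.callbacks import ChainType, RagasTracer, new_group, parse_run_traces

tracer = RagasTracer()

# Root "evaluation" chain; the tracer stores a ChainRun for every run_id it sees.
eval_rm, eval_group = new_group(
    name="evaluation",
    inputs={},
    callbacks=[tracer],
    metadata={"type": ChainType.EVALUATION},
)

# A child "row" chain, nested under the root via the group callback manager.
row_rm, row_group = new_group(
    name="row 0",
    inputs={"user_input": "What is Ragas?"},
    callbacks=eval_group,
    metadata={"type": ChainType.ROW, "row_index": 0},
)

row_rm.on_chain_end({"output": "done"})
eval_rm.on_chain_end({"output": "done"})

print(tracer.to_jsons())                # JSON dump of the recorded run tree
print(parse_run_traces(tracer.traces))  # one (empty) MetricTrace per row here
```
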
35 changes: 13 additions & 22 deletions src/ragas/dataset_schema.py
@@ -8,16 +8,19 @@
from datasets import Dataset as HFDataset
from pydantic import BaseModel, field_validator

from ragas.callbacks import parse_run_traces
from ragas.cost import CostCallbackHandler
from ragas.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from ragas.utils import safe_nanmean

if t.TYPE_CHECKING:
import uuid
from pathlib import Path

from datasets import Dataset as HFDataset
from pandas import DataFrame as PandasDataframe

from ragas.callbacks import ChainRun
from ragas.cost import TokenUsage


@@ -137,6 +140,7 @@ def pretty_repr(self):


Sample = t.TypeVar("Sample", bound=BaseSample)
T = t.TypeVar("T", bound="RagasDataset")


class RagasDataset(ABC, BaseModel, t.Generic[Sample]):
@@ -149,7 +153,7 @@ def to_list(self) -> t.List[t.Dict]:

@classmethod
@abstractmethod
def from_list(cls, data: t.List[t.Dict]) -> RagasDataset[Sample]:
def from_list(cls: t.Type[T], data: t.List[t.Dict]) -> T:
"""Creates an EvaluationDataset from a list of dictionaries."""
pass

@@ -181,7 +185,7 @@ def to_hf_dataset(self) -> HFDataset:
return HFDataset.from_list(self.to_list())

@classmethod
def from_hf_dataset(cls, dataset: HFDataset):
def from_hf_dataset(cls: t.Type[T], dataset: HFDataset) -> T:
"""Creates an EvaluationDataset from a Hugging Face Dataset."""
return cls.from_list(dataset.to_list())

@@ -202,7 +206,7 @@ def features(self):
return self.samples[0].get_features()

@classmethod
def from_dict(cls, mapping: t.Dict):
def from_dict(cls: t.Type[T], mapping: t.Dict) -> T:
"""Creates an EvaluationDataset from a dictionary."""
samples = []
if all(
@@ -237,7 +241,7 @@ def to_jsonl(self, path: t.Union[str, Path]):
jsonlfile.write(json.dumps(sample.to_dict(), ensure_ascii=False) + "\n")

@classmethod
def from_jsonl(cls, path: t.Union[str, Path]):
def from_jsonl(cls: t.Type[T], path: t.Union[str, Path]) -> T:
"""Creates an EvaluationDataset from a JSONL file."""
with open(path, "r") as jsonlfile:
data = [json.loads(line) for line in jsonlfile]
@@ -334,12 +338,6 @@ def from_list(cls, data: t.List[t.Dict]) -> EvaluationDataset:
return cls(samples=samples)


class EvaluationResultRow(BaseModel):
dataset_row: t.Dict
scores: t.Dict[str, t.Any]
trace: t.Dict[str, t.Any] = field(default_factory=dict) # none for now


@dataclass
class EvaluationResult:
"""
@@ -361,6 +359,8 @@ class EvaluationResult:
dataset: EvaluationDataset
binary_columns: t.List[str] = field(default_factory=list)
cost_cb: t.Optional[CostCallbackHandler] = None
traces: t.List[t.Dict[str, t.Any]] = field(default_factory=list)
ragas_traces: t.Dict[uuid.UUID, ChainRun] = field(default_factory=dict, repr=False)

def __post_init__(self):
# transform scores from list of dicts to dict of lists
@@ -377,6 +377,9 @@ def __post_init__(self):
value = t.cast(float, value)
values.append(value + 1e-10)

# parse the traces
self.traces = parse_run_traces(self.ragas_traces)

def to_pandas(self, batch_size: int | None = None, batched: bool = False):
"""
Convert the result to a pandas DataFrame.
@@ -413,18 +416,6 @@ def to_pandas(self, batch_size: int | None = None, batched: bool = False):
dataset_df = self.dataset.to_pandas()
return pd.concat([dataset_df, scores_df], axis=1)

def serialized(self) -> t.List[EvaluationResultRow]:
"""
Convert the result to a list of EvaluationResultRow.
"""
return [
EvaluationResultRow(
dataset_row=self.dataset[i].to_dict(),
scores=self.scores[i],
)
for i in range(len(self.scores))
]

def total_tokens(self) -> t.Union[t.List[TokenUsage], TokenUsage]:
"""
Compute the total tokens used in the evaluation.
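
On the result side, `serialized()` and `EvaluationResultRow` are dropped in favour of two new fields: `ragas_traces` carries the raw run tree collected by the tracer, and `__post_init__` parses it into `traces`, one entry per evaluated row. A hedged sketch of how the new field is meant to be read after a run; `dataset` is assumed to be an existing `EvaluationDataset`, an evaluator LLM is assumed to be configured as usual, and the printed values are illustrative:

```python
# Hedged usage sketch: `dataset` is an assumed EvaluationDataset and the
# printed values are illustrative, not real outputs.
from ragas import evaluate
from ragas.metrics import Faithfulness

result = evaluate(dataset, metrics=[Faithfulness()])

row_trace = result.traces[0]       # one MetricTrace per dataset row
print(row_trace.scores)            # e.g. {"faithfulness": 0.8}
print(row_trace["faithfulness"])   # {"0_<prompt_name>": {"input": ..., "output": ...}}
```
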
13 changes: 11 additions & 2 deletions src/ragas/evaluation.py
@@ -9,7 +9,7 @@
from langchain_core.language_models import BaseLanguageModel as LangchainLLM

from ragas._analytics import EvaluationEvent, track, track_was_completed
from ragas.callbacks import new_group
from ragas.callbacks import ChainType, RagasTracer, new_group
from ragas.dataset_schema import (
EvaluationDataset,
EvaluationResult,
@@ -229,6 +229,10 @@ def evaluate(
# init the callbacks we need for various tasks
ragas_callbacks: t.Dict[str, BaseCallbackHandler] = {}

# Ragas Tracer which traces the run
tracer = RagasTracer()
ragas_callbacks["tracer"] = tracer

# check if cost needs to be calculated
if token_usage_parser is not None:
from ragas.cost import CostCallbackHandler
@@ -246,7 +250,10 @@
# new evaluation chain
row_run_managers = []
evaluation_rm, evaluation_group_cm = new_group(
name=RAGAS_EVALUATION_CHAIN_NAME, inputs={}, callbacks=callbacks
name=RAGAS_EVALUATION_CHAIN_NAME,
inputs={},
callbacks=callbacks,
metadata={"type": ChainType.EVALUATION},
)

sample_type = dataset.get_sample_type()
@@ -256,6 +263,7 @@
name=f"row {i}",
inputs=row,
callbacks=evaluation_group_cm,
metadata={"type": ChainType.ROW, "row_index": i},
)
row_run_managers.append((row_rm, row_group_cm))
if sample_type == SingleTurnSample:
@@ -321,6 +329,7 @@ def evaluate(
t.Union["CostCallbackHandler", None],
cost_cb,
),
ragas_traces=tracer.traces,
)
if not evaluation_group_cm.ended:
evaluation_rm.on_chain_end(result)
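
Inside `evaluate()` the tracer is registered next to the cost callback, the root chain and every per-row chain get a `ChainType` in their metadata, and the collected run tree is passed to `EvaluationResult` as `ragas_traces`. Because the raw tree travels with the result, the per-row traces can be re-derived or dumped later; a sketch, assuming `result` comes from a prior `evaluate()` call:

```python
# Sketch only: re-deriving and dumping traces from the raw run tree kept on the
# result. Assumes `result` is an EvaluationResult from a prior evaluate() call.
import json

from ragas.callbacks import ChainRunEncoder, parse_run_traces

raw = result.ragas_traces            # run_id -> ChainRun
per_row = parse_run_traces(raw)      # same shape as result.traces
print(json.dumps(
    [run.model_dump() for run in raw.values()],
    cls=ChainRunEncoder,
    indent=2,
))
```
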
6 changes: 3 additions & 3 deletions src/ragas/metrics/_aspect_critic.py
@@ -117,7 +117,7 @@ class AspectCritic(MetricWithLLM, SingleTurnMetric, MultiTurnMetric):
strictness: int = field(default=1, repr=False)
max_retries: int = 1

def __post_init__(self: t.Self):
def __post_init__(self):
if self.name == "":
raise ValueError("Expects a name")
if self.definition == "":
@@ -141,12 +141,12 @@ def _compute_score(
return score

async def _single_turn_ascore(
self: t.Self, sample: SingleTurnSample, callbacks: Callbacks
self, sample: SingleTurnSample, callbacks: Callbacks
) -> float:
row = sample.to_dict()
return await self._ascore(row, callbacks)

async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
assert self.llm is not None, "set LLM before use"

user_input, context, response = (
4 changes: 2 additions & 2 deletions src/ragas/metrics/_faithfulness.py
@@ -275,7 +275,7 @@ async def _single_turn_ascore(
row = sample.to_dict()
return await self._ascore(row, callbacks)

async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
"""
returns the NLI score for each (q, c, a) pair
"""
@@ -330,7 +330,7 @@ def _create_batch(
for ndx in range(0, length_of_pairs, self.batch_size):
yield pairs[ndx : min(ndx + self.batch_size, length_of_pairs)]

async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
"""
returns the NLI score for each (q, c, a) pair
"""