Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tool call accuracy. #1665

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/ragas/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
StringPresence,
)
from ragas.metrics._summarization import SummarizationScore, summarization_score
from ragas.metrics._tool_call_accuracy import ToolCallAccuracy
from ragas.metrics._tool_call_accuracy import ToolCallAccuracy, ToolCallParallelAccuracy
from ragas.metrics._topic_adherence import TopicAdherenceScore

__all__ = [
Expand Down Expand Up @@ -108,6 +108,7 @@
"AgentGoalAccuracyWithoutReference",
"AgentGoalAccuracyWithReference",
"ToolCallAccuracy",
"ToolCallParallelAccuracy",
"ResponseRelevancy",
"SemanticSimilarity",
"DistanceMeasure",
Expand Down
75 changes: 65 additions & 10 deletions src/ragas/metrics/_tool_call_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field

from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
from ragas.messages import AIMessage
from ragas.messages import AIMessage, ToolCall
from ragas.metrics._string import ExactMatch
from ragas.metrics.base import MetricType, MultiTurnMetric, SingleTurnMetric

Expand Down Expand Up @@ -33,7 +33,7 @@ def init(self, run_config):
pass

async def _get_arg_score(
self, preds: t.Dict[str, t.Any], refs: t.Dict[str, t.Any], callbacks: Callbacks
self, preds: t.Dict[str, t.Any], refs: t.Dict[str, t.Any], callbacks: Callbacks
) -> float:
score = 0.0
for arg in refs.keys():
Expand All @@ -48,7 +48,7 @@ async def _get_arg_score(
return score / len(refs.keys())

def is_sequence_aligned(
self, pred_sequence: t.List[str], ref_sequence: t.List[str]
self, pred_sequence: t.List[str], ref_sequence: t.List[str]
) -> bool:
ref_index = 0 # Index to track position in reference sequence
for pred in pred_sequence:
Expand Down Expand Up @@ -82,13 +82,12 @@ async def _multi_turn_ascore(
if pred_tool_calls:
score = 0.0
reference_tool_calls = sample.reference_tool_calls
for ref_tool_call in reference_tool_calls:
for pred_tool_call in pred_tool_calls:
if ref_tool_call.name == pred_tool_call.name:
arg_score = await self._get_arg_score(
pred_tool_call.args, ref_tool_call.args, callbacks
)
score += arg_score
for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
if ref_tool_call.name == pred_tool_call.name:
arg_score = await self._get_arg_score(
pred_tool_call.args, ref_tool_call.args, callbacks
)
score += arg_score

score /= len(reference_tool_calls)
else:
Expand All @@ -99,3 +98,59 @@ async def _multi_turn_ascore(

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
    """Adapt a raw row dict to the multi-turn scorer.

    Validates *row* into a ``MultiTurnSample`` and delegates to
    ``_multi_turn_ascore`` for the actual scoring.
    """
    sample = MultiTurnSample(**row)
    return await self._multi_turn_ascore(sample, callbacks)


@dataclass
class ToolCallParallelAccuracy(ToolCallAccuracy):
    """Order-insensitive variant of ``ToolCallAccuracy``.

    Both the predicted and the reference tool calls are sorted by a
    deterministic key (tool name, then argument name/value pairs) before
    comparison, so parallel tool calls emitted in a different order than
    the reference still receive full credit.
    """

    name: str = "tool_call_parallel_accuracy"

    @staticmethod
    def _sorted_key_for_tool_call(tc: ToolCall) -> t.Tuple[str, ...]:
        """Return a deterministic sort key for a tool call.

        The key is ``(name, arg1_name, arg1_value, arg2_name, ...)`` with
        argument names sorted; values go through ``str()`` so non-string
        argument values (ints, dicts, ...) remain comparable.
        """
        key_list = [tc.name]
        for arg_name in sorted(tc.args):
            key_list.append(arg_name)
            key_list.append(str(tc.args[arg_name]))
        return tuple(key_list)

    async def _multi_turn_ascore(
        self, sample: MultiTurnSample, callbacks: Callbacks
    ) -> float:
        """Score tool-call accuracy ignoring the order of parallel calls.

        Returns 0.0 when no predicted tool calls are present (with a
        warning) or when the reference tool-call list is empty; otherwise
        the mean per-call argument score over the reference calls, gated
        by sequence alignment of the sorted tool-call names.
        """
        assert sample.reference_tool_calls is not None, "Reference tool calls is not set"

        pred_tool_calls: t.List[ToolCall] = []
        for item in sample.user_input:
            if isinstance(item, AIMessage) and item.tool_calls is not None:
                pred_tool_calls.extend(item.tool_calls)

        # Bail out early before doing any sorting/alignment work.
        if not pred_tool_calls:
            warnings.warn("No tool calls found in the user input")
            return 0.0

        reference_tool_calls = sorted(
            sample.reference_tool_calls, key=self._sorted_key_for_tool_call
        )
        # Guard against ZeroDivisionError: predictions exist but the
        # reference list is empty, so no credit can be assigned.
        if not reference_tool_calls:
            return 0.0

        pred_tool_calls = sorted(pred_tool_calls, key=self._sorted_key_for_tool_call)

        sequence_aligned = int(
            self.is_sequence_aligned(
                [tool_call.name for tool_call in pred_tool_calls],
                [tool_call.name for tool_call in reference_tool_calls],
            )
        )

        score = 0.0
        # zip truncates to the shorter list; extra or missing predicted
        # calls simply earn no credit relative to len(reference_tool_calls).
        for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
            if ref_tool_call.name == pred_tool_call.name:
                arg_score = await self._get_arg_score(
                    pred_tool_call.args, ref_tool_call.args, callbacks
                )
                score += arg_score
        score /= len(reference_tool_calls)

        return score * sequence_aligned
Loading