diff --git a/src/ragas/metrics/__init__.py b/src/ragas/metrics/__init__.py
index ebf92ebbc..77b6e7b60 100644
--- a/src/ragas/metrics/__init__.py
+++ b/src/ragas/metrics/__init__.py
@@ -62,7 +62,7 @@
     StringPresence,
 )
 from ragas.metrics._summarization import SummarizationScore, summarization_score
-from ragas.metrics._tool_call_accuracy import ToolCallAccuracy
+from ragas.metrics._tool_call_accuracy import ToolCallAccuracy, ToolCallParallelAccuracy
 from ragas.metrics._topic_adherence import TopicAdherenceScore
 
 __all__ = [
@@ -108,6 +108,7 @@
     "AgentGoalAccuracyWithoutReference",
     "AgentGoalAccuracyWithReference",
     "ToolCallAccuracy",
+    "ToolCallParallelAccuracy",
     "ResponseRelevancy",
     "SemanticSimilarity",
     "DistanceMeasure",
diff --git a/src/ragas/metrics/_tool_call_accuracy.py b/src/ragas/metrics/_tool_call_accuracy.py
index b89a23b49..695a22bbf 100644
--- a/src/ragas/metrics/_tool_call_accuracy.py
+++ b/src/ragas/metrics/_tool_call_accuracy.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass, field
 
 from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
-from ragas.messages import AIMessage
+from ragas.messages import AIMessage, ToolCall
 from ragas.metrics._string import ExactMatch
 from ragas.metrics.base import MetricType, MultiTurnMetric, SingleTurnMetric
 
@@ -33,7 +33,7 @@ def init(self, run_config):
         pass
 
     async def _get_arg_score(
-        self, preds: t.Dict[str, t.Any], refs: t.Dict[str, t.Any], callbacks: Callbacks
+        self, preds: t.Dict[str, t.Any], refs: t.Dict[str, t.Any], callbacks: Callbacks
     ) -> float:
         score = 0.0
         for arg in refs.keys():
@@ -48,7 +48,7 @@
         return score / len(refs.keys())
 
     def is_sequence_aligned(
-        self, pred_sequence: t.List[str], ref_sequence: t.List[str]
+        self, pred_sequence: t.List[str], ref_sequence: t.List[str]
     ) -> bool:
         ref_index = 0  # Index to track position in reference sequence
         for pred in pred_sequence:
@@ -82,13 +82,12 @@
         if pred_tool_calls:
             score = 0.0
             reference_tool_calls = sample.reference_tool_calls
-            for ref_tool_call in reference_tool_calls:
-                for pred_tool_call in pred_tool_calls:
-                    if ref_tool_call.name == pred_tool_call.name:
-                        arg_score = await self._get_arg_score(
-                            pred_tool_call.args, ref_tool_call.args, callbacks
-                        )
-                        score += arg_score
+            for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
+                if ref_tool_call.name == pred_tool_call.name:
+                    arg_score = await self._get_arg_score(
+                        pred_tool_call.args, ref_tool_call.args, callbacks
+                    )
+                    score += arg_score
 
             score /= len(reference_tool_calls)
         else:
@@ -99,3 +98,61 @@
 
     async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
         return await self._multi_turn_ascore(MultiTurnSample(**row), callbacks)
+
+
+@dataclass
+class ToolCallParallelAccuracy(ToolCallAccuracy):
+    name: str = "tool_call_parallel_accuracy"
+
+    @staticmethod
+    def _sorted_key_for_tool_call(tc: ToolCall) -> t.Tuple[str, ...]:
+        key_list = [tc.name]
+        args = tc.args
+        args_name = sorted(args)
+        for name in args_name:
+            key_list.append(name)
+            key_list.append(str(args[name]))
+
+        return tuple(key_list)
+
+    async def _multi_turn_ascore(
+        self, sample: MultiTurnSample, callbacks: Callbacks
+    ) -> float:
+        assert sample.reference_tool_calls is not None, "Reference tool calls is not set"
+
+        pred_tool_calls = []
+        for item in sample.user_input:
+            if isinstance(item, AIMessage) and item.tool_calls is not None:
+                pred_tool_calls.extend(item.tool_calls)
+
+        # Sort the tool calls first.
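+        # The sort key is (name, sorted arg names and values), so predictions
+        # and references line up pairwise below no matter what order the
+        # parallel calls were emitted in.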
+        pred_tool_calls = sorted(pred_tool_calls, key=self._sorted_key_for_tool_call)
+        reference_tool_calls = sorted(sample.reference_tool_calls, key=self._sorted_key_for_tool_call)
+
+        tool_call_pred_sequence = [tool_call.name for tool_call in pred_tool_calls]
+        tool_call_ref_sequence = [
+            tool_call.name for tool_call in reference_tool_calls
+        ]
+
+        sequence_aligned = int(
+            self.is_sequence_aligned(tool_call_pred_sequence, tool_call_ref_sequence)
+        )
+
+        if pred_tool_calls:
+            score = 0.0
+            for ref_tool_call, pred_tool_call in zip(reference_tool_calls, pred_tool_calls):
+                if ref_tool_call.name == pred_tool_call.name:
+                    arg_score = await self._get_arg_score(
+                        pred_tool_call.args, ref_tool_call.args, callbacks
+                    )
+                    score += arg_score
+
+            score /= len(reference_tool_calls)
+        else:
+            warnings.warn("No tool calls found in the user input")
+            return 0.0
+
+        return score * sequence_aligned
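Reviewer note: a minimal usage sketch (not part of the patch) showing the intended difference between the two metrics. It assumes the export added to ragas/metrics/__init__.py above; the weather_check tool, its arguments, and the conversation are hypothetical. With the canonical sort, swapped parallel calls should score 1.0, while the order-sensitive ToolCallAccuracy pairs calls by position and scores 0.0.

import asyncio

from ragas.dataset_schema import MultiTurnSample
from ragas.messages import AIMessage, HumanMessage, ToolCall
from ragas.metrics import ToolCallAccuracy, ToolCallParallelAccuracy


async def main():
    # Hypothetical sample: the model emits two parallel weather_check calls
    # in the opposite order from the reference.
    sample = MultiTurnSample(
        user_input=[
            HumanMessage(content="What's the weather in Paris and Tokyo?"),
            AIMessage(
                content="Checking both cities.",
                tool_calls=[
                    ToolCall(name="weather_check", args={"location": "Tokyo"}),
                    ToolCall(name="weather_check", args={"location": "Paris"}),
                ],
            ),
        ],
        reference_tool_calls=[
            ToolCall(name="weather_check", args={"location": "Paris"}),
            ToolCall(name="weather_check", args={"location": "Tokyo"}),
        ],
    )

    # Positional pairing mismatches the args, so the base metric scores 0.0.
    print(await ToolCallAccuracy().multi_turn_ascore(sample))  # expected: 0.0
    # Canonical sorting realigns the calls, so the parallel variant scores 1.0.
    print(await ToolCallParallelAccuracy().multi_turn_ascore(sample))  # expected: 1.0


asyncio.run(main())

After the sort, pairing is still positional (zip), so duplicate tool names with different args compare one-to-one rather than every-to-every.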