Skip to content

Commit

Permalink
fix: type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jjmachan committed Oct 18, 2024
1 parent d79791e commit cd359c3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
16 changes: 10 additions & 6 deletions src/ragas/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ class SingleTurnMetric(Metric):
This class provides methods to score single-turn samples, both synchronously and asynchronously.
"""

def _only_required_columns(self, sample: SingleTurnSample) -> SingleTurnSample:
def _only_required_columns_single_turn(
self, sample: SingleTurnSample
) -> SingleTurnSample:
"""
Simplify the sample to only include the required columns.
"""
Expand All @@ -224,7 +226,7 @@ def single_turn_score(
"""
callbacks = callbacks or []
# only get the required columns
sample = self._only_required_columns(sample)
sample = self._only_required_columns_single_turn(sample)
rm, group_cm = new_group(
self.name,
inputs=sample.to_dict(),
Expand Down Expand Up @@ -267,7 +269,7 @@ async def single_turn_ascore(
"""
callbacks = callbacks or []
# only get the required columns
sample = self._only_required_columns(sample)
sample = self._only_required_columns_single_turn(sample)
rm, group_cm = new_group(
self.name,
inputs=sample.to_dict(),
Expand Down Expand Up @@ -308,7 +310,9 @@ class MultiTurnMetric(Metric):
for scoring multi-turn conversation samples.
"""

def _only_required_columns(self, sample: MultiTurnSample) -> MultiTurnSample:
def _only_required_columns_multi_turn(
self, sample: MultiTurnSample
) -> MultiTurnSample:
"""
Simplify the sample to only include the required columns.
"""
Expand All @@ -328,7 +332,7 @@ def multi_turn_score(
May raise ImportError if nest_asyncio is not installed in Jupyter-like environments.
"""
callbacks = callbacks or []
sample = self._only_required_columns(sample)
sample = self._only_required_columns_multi_turn(sample)
rm, group_cm = new_group(
self.name,
inputs=sample.to_dict(),
Expand Down Expand Up @@ -370,7 +374,7 @@ async def multi_turn_ascore(
May raise asyncio.TimeoutError if the scoring process exceeds the specified timeout.
"""
callbacks = callbacks or []
sample = self._only_required_columns(sample)
sample = self._only_required_columns_multi_turn(sample)

rm, group_cm = new_group(
self.name,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def _single_turn_ascore(self, sample: SingleTurnSample, callbacks):
"response",
}
assert (
fm._only_required_columns(
fm._only_required_columns_single_turn(
SingleTurnSample(user_input="a", response="b", reference="c")
).to_dict()
== SingleTurnSample(user_input="a", response="b").to_dict()
Expand Down

0 comments on commit cd359c3

Please sign in to comment.