
Commit 059a528

merged with main
1 parent 2ef0513 commit 059a528

17 files changed: +311 -77 lines changed

.github/workflows/ci.yaml

Lines changed: 3 additions & 0 deletions
@@ -96,6 +96,9 @@ jobs:
           fi
           # Now run the unit tests
           pytest tests/unit "${OPTS[@]}"
+        env:
+          __RAGAS_DEBUG_TRACKING: true
+          RAGAS_DO_NOT_TRACK: true
 
   codestyle_check:
     runs-on: ubuntu-latest
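These are the same two environment variables the analytics module reads (see src/ragas/_analytics.py below): RAGAS_DO_NOT_TRACK opts out of event sending and __RAGAS_DEBUG_TRACKING makes tracking errors loud instead of silent. A rough local equivalent, assuming the variables are consumed as shown in the diff below:

import os

# Opt out of usage tracking and surface tracking errors, mirroring the CI job
# above. Set these before the first ragas call: the checks are lru_cache'd,
# so later changes to the environment are ignored for the rest of the process.
os.environ["RAGAS_DO_NOT_TRACK"] = "true"
os.environ["__RAGAS_DEBUG_TRACKING"] = "true"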

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ dependencies = [
     "openai>1",
     "pysbd>=0.3.4",
     "nest-asyncio",
+    "appdirs",
 ]
 dynamic = ["version", "readme"]
 
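The new appdirs dependency backs the user_data_dir() call used by get_userid() in _analytics.py below. A tiny sketch of what it resolves to (paths vary by OS; shown values are typical, not guaranteed):

from appdirs import user_data_dir

# Directory where get_userid() keeps uuid.json, e.g. ~/.local/share/ragas on
# Linux or ~/Library/Application Support/ragas on macOS.
print(user_data_dir(appname="ragas"))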

src/ragas/_analytics.py

Lines changed: 28 additions & 9 deletions
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
+import json
 import logging
 import os
 import typing as t
-from dataclasses import asdict, dataclass
+import uuid
 from functools import lru_cache, wraps
 
 import requests
+from appdirs import user_data_dir
+from langchain_core.pydantic_v1 import BaseModel, Field
 
 from ragas.utils import get_debug_mode
 
@@ -19,9 +22,11 @@
 
 
 USAGE_TRACKING_URL = "https://t.explodinggradients.com"
+USAGE_REQUESTS_TIMEOUT_SEC = 1
+USER_DATA_DIR_NAME = "ragas"
+# Any chance you chance this also change the variable in our ci.yaml file
 RAGAS_DO_NOT_TRACK = "RAGAS_DO_NOT_TRACK"
 RAGAS_DEBUG_TRACKING = "__RAGAS_DEBUG_TRACKING"
-USAGE_REQUESTS_TIMEOUT_SEC = 1
 
 
 @lru_cache(maxsize=1)
@@ -33,7 +38,7 @@ def do_not_track() -> bool:  # pragma: no cover
 
 @lru_cache(maxsize=1)
 def _usage_event_debugging() -> bool:
-    # For BentoML developers only - debug and print event payload if turned on
+    # For Ragas developers only - debug and print event payload if turned on
     return os.environ.get(RAGAS_DEBUG_TRACKING, str(False)).lower() == "true"
 
 
@@ -49,6 +54,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
                     logger.error(
                         "Tracking Error: %s", err, stack_info=True, stacklevel=3
                     )
+                    raise err
                 else:
                     logger.info("Tracking Error: %s", err)
             else:
@@ -57,14 +63,28 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
     return wrapper
 
 
-@dataclass
-class BaseEvent:
+@lru_cache(maxsize=1)
+@silent
+def get_userid() -> str:
+    user_id_path = user_data_dir(appname=USER_DATA_DIR_NAME)
+    uuid_filepath = os.path.join(user_id_path, "uuid.json")
+    if os.path.exists(uuid_filepath):
+        user_id = json.load(open(uuid_filepath))["userid"]
+    else:
+        user_id = "a-" + uuid.uuid4().hex
+        os.makedirs(user_id_path)
+        with open(uuid_filepath, "w") as f:
+            json.dump({"userid": user_id}, f)
+    return user_id
+
+
+class BaseEvent(BaseModel):
     event_type: str
+    user_id: str = Field(default_factory=get_userid)
 
 
-@dataclass
 class EvaluationEvent(BaseEvent):
-    metrics: list[str]
+    metrics: t.List[str]
     evaluation_mode: str
     num_rows: int
 
@@ -74,8 +94,7 @@ def track(event_properties: BaseEvent):
    if do_not_track():
        return
 
-    payload = asdict(event_properties)
-
+    payload = dict(event_properties)
    if _usage_event_debugging():
        # For internal debugging purpose
        logger.info("Tracking Payload: %s", payload)
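Taken together, these changes switch the telemetry events from dataclasses to pydantic models and stamp each event with a persistent anonymous id from get_userid(), which caches a generated UUID in uuid.json under the appdirs user data dir. A minimal usage sketch; the event_type, metric names, and evaluation_mode values are illustrative, not taken from this commit:

from ragas._analytics import EvaluationEvent, track

event = EvaluationEvent(
    event_type="evaluation",                       # illustrative value
    metrics=["faithfulness", "answer_relevancy"],  # illustrative values
    evaluation_mode="qac",                         # illustrative value
    num_rows=10,
)
# user_id is filled in by the Field default_factory, so every event from the
# same machine carries the same "a-<uuid4 hex>" identifier across runs.
track(event)  # returns immediately when RAGAS_DO_NOT_TRACK is "true"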

src/ragas/evaluation.py

Lines changed: 6 additions & 1 deletion
@@ -144,7 +144,12 @@ def evaluate(
         row_run_managers.append((row_rm, row_group_cm))
 
         if is_async:
-            [executor.submit(metric.ascore, row, row_group_cm) for metric in metrics]
+            [
+                executor.submit(
+                    metric.ascore, row, row_group_cm, name=f"{metric.name}-{i}"
+                )
+                for metric in metrics
+            ]
         else:
             [executor.submit(metric.score, row, row_group_cm) for metric in metrics]
 

src/ragas/executor.py

Lines changed: 6 additions & 2 deletions
@@ -43,7 +43,9 @@ async def wrapped_callable_async(*args, **kwargs):
         else:
             return wrapped_callable
 
-    def submit(self, callable: t.Callable, *args, **kwargs):
+    def submit(
+        self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs
+    ):
         if self.is_async:
             self.executor = t.cast(asyncio.AbstractEventLoop, self.executor)
             callable_with_index = self.wrap_callable_with_index(
@@ -52,7 +54,9 @@ def submit(self, callable: t.Callable, *args, **kwargs):
             # is type correct?
             callable_with_index = t.cast(t.Callable, callable_with_index)
             self.futures.append(
-                self.executor.create_task(callable_with_index(*args, **kwargs))
+                self.executor.create_task(
+                    callable_with_index(*args, **kwargs), name=name
+                )
             )
         else:
             self.executor = t.cast(ThreadPoolExecutor, self.executor)
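The new name argument is forwarded to asyncio's task creation, which is what the evaluation.py change above uses to label each scheduled coroutine as "{metric.name}-{i}". A standalone sketch of what named tasks buy, in plain asyncio rather than ragas code (score_row is a made-up stand-in for metric.ascore):

import asyncio

async def score_row(i: int) -> int:
    # Stand-in for an async metric scoring call.
    await asyncio.sleep(0)
    return i

async def main() -> None:
    loop = asyncio.get_running_loop()
    # Named tasks show up in asyncio debug output and asyncio.all_tasks(),
    # making it easier to tell which metric/row a failure came from.
    tasks = [
        loop.create_task(score_row(i), name=f"faithfulness-{i}") for i in range(3)
    ]
    for task in tasks:
        print(task.get_name(), "->", await task)

asyncio.run(main())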

src/ragas/llms/base.py

Lines changed: 5 additions & 2 deletions
@@ -9,12 +9,13 @@
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.outputs import LLMResult
 
-from ragas.llms.prompt import PromptValue
-
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
     from langchain_core.prompts import ChatPromptTemplate
 
+    from ragas.llms.prompt import PromptValue
+
+
 MULTIPLE_COMPLETION_SUPPORTED = [
     OpenAI,
     ChatOpenAI,
@@ -66,6 +67,8 @@ def generate_text_with_hmpt(
         stop: t.Optional[t.List[str]] = None,
         callbacks: Callbacks = [],
     ) -> LLMResult:
+        from ragas.llms.prompt import PromptValue
+
         prompt = PromptValue(prompt_str=prompts[0].format())
         return self.generate_text(prompt, n, temperature, stop, callbacks)
 
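Both edits are the usual way to break a circular import between ragas.llms.base and ragas.llms.prompt: keep the name visible to type checkers via TYPE_CHECKING, and defer the runtime import into the function body. A generic sketch of the pattern, not the actual ragas module (make_prompt_value is a hypothetical helper):

from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    # Seen only by type checkers, so importing it here cannot create a
    # runtime import cycle.
    from ragas.llms.prompt import PromptValue


def make_prompt_value(text: str) -> PromptValue:
    # Deferred import: executed on first call, after both modules are loaded.
    from ragas.llms.prompt import PromptValue

    return PromptValue(prompt_str=text)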

src/ragas/llms/json_load.py

Lines changed: 27 additions & 10 deletions
@@ -82,21 +82,38 @@ def safe_load(self, text: str, llm: BaseRagasLLM, callbacks: Callbacks = None):
                 start, end = self._find_outermost_json(text)
                 return json.loads(text[start:end])
             except ValueError:
-                text = self._fix_to_json(text, llm, callbacks)
+                from ragas.llms.prompt import PromptValue
+
+                results = llm.generate_text(
+                    PromptValue(prompt_str=JSON_PROMPT.format(input=text)),
+                    n=1,
+                    callbacks=callbacks,
+                )
+                text = results.generations[0][0].text
             retry += 1
 
         return {}
 
-    def _fix_to_json(self, text: str, llm: BaseRagasLLM, callbacks: Callbacks):
-        from ragas.llms.prompt import PromptValue
+    async def asafe_load(
+        self, text: str, llm: BaseRagasLLM, callbacks: Callbacks = None
+    ):
+        retry = 0
+        while retry <= self.max_retries:
+            try:
+                start, end = self._find_outermost_json(text)
+                return json.loads(text[start:end])
+            except ValueError:
+                from ragas.llms.prompt import PromptValue
+
+                results = await llm.agenerate_text(
+                    PromptValue(prompt_str=JSON_PROMPT.format(input=text)),
+                    n=1,
+                    callbacks=callbacks,
+                )
+                text = results.generations[0][0].text
+            retry += 1
 
-        # TODO (executor)
-        results = llm.generate_text(
-            PromptValue(prompt_str=JSON_PROMPT.format(input=text)),
-            n=1,
-            callbacks=callbacks,
-        )
-        return results.generations[0][0].text
+        return {}
 
     def _find_outermost_json(self, text):
         stack = []
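asafe_load is the async twin of safe_load: extract the outermost JSON object, and on a parse failure ask the LLM (generate_text / agenerate_text with JSON_PROMPT) to repair the text, up to max_retries times. A small sketch of the happy path, where the embedded JSON is already valid so the LLM is never consulted; passing llm=None is an illustration-only shortcut, since the metrics pass their configured BaseRagasLLM:

from ragas.llms.json_load import json_loader

text = 'The verdict is: {"verdict": "yes", "reason": "statement is supported"}'

# _find_outermost_json() locates the {...} span and json.loads() succeeds on
# the first attempt, so no fix-up prompt is ever sent to an LLM.
parsed = json_loader.safe_load(text, llm=None)
print(parsed)  # {'verdict': 'yes', 'reason': 'statement is supported'}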

src/ragas/llms/prompt.py

Lines changed: 1 addition & 3 deletions
@@ -9,12 +9,10 @@
 from langchain_core.prompt_values import PromptValue as BasePromptValue
 from langchain_core.pydantic_v1 import BaseModel, root_validator
 
+from ragas.llms import BaseRagasLLM
 from ragas.llms.json_load import json_loader
 from ragas.utils import get_cache_dir
 
-if t.TYPE_CHECKING:
-    from ragas.llms import BaseRagasLLM
-
 Example = t.Dict[str, t.Any]
 
 

src/ragas/metrics/_answer_correctness.py

Lines changed: 9 additions & 5 deletions
@@ -15,7 +15,6 @@
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
-    from langchain_core.outputs import LLMResult
 
 CORRECTNESS_PROMPT = Prompt(
     name="answer_correctness",
@@ -110,7 +109,7 @@ def __post_init__(self: t.Self):
             llm=self.llm, batch_size=self.batch_size
         )
 
-    def _compute_statement_presence(self, result: LLMResult) -> float:
+    def _compute_statement_presence(self, prediction: t.Any) -> float:
         assert self.llm is not None, "LLM must be set"
 
         key_map = {
@@ -120,7 +119,6 @@ def _compute_statement_presence(self, result: LLMResult) -> float:
         }
         outputs = result.generations[0]
 
-        prediction = json_loader.safe_load(outputs[0].text, self.llm)
         prediction = prediction if isinstance(prediction, list) else [prediction]
         if prediction:
             prediction = [
@@ -146,7 +144,10 @@ def _score(self, row: t.Dict, callbacks: Callbacks) -> float:
         p_value = self.correctness_prompt.format(question=q, ground_truth=g, answer=a)
         is_statement_present = self.llm.generate_text(p_value, callbacks=callbacks)
 
-        f1_score = self._compute_statement_presence(is_statement_present)
+        prediction = json_loader.safe_load(
+            is_statement_present.generations[0][0].text, self.llm
+        )
+        f1_score = self._compute_statement_presence(prediction)
 
         if self.weights[1] == 0:
             similarity_score = 0
@@ -169,7 +170,10 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
             p_value, callbacks=callbacks
         )
 
-        f1_score = self._compute_statement_presence(is_statement_present)
+        prediction = await json_loader.asafe_load(
+            is_statement_present.generations[0][0].text, self.llm
+        )
+        f1_score = self._compute_statement_presence(prediction)
 
         if self.weights[1] == 0:
             similarity_score = 0

src/ragas/metrics/_answer_relevance.py

Lines changed: 2 additions & 7 deletions
@@ -5,10 +5,8 @@
 from dataclasses import dataclass, field
 
 import numpy as np
-from langchain.embeddings import OpenAIEmbeddings
 
 from ragas.embeddings.base import embedding_factory
-from ragas.exceptions import OpenAIKeyNotFound
 from ragas.llms.json_load import json_loader
 from ragas.llms.prompt import Prompt
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
@@ -83,10 +81,6 @@ class AnswerRelevancy(MetricWithLLM):
     def init_model(self):
         super().init_model()
 
-        if isinstance(self.embeddings, OpenAIEmbeddings):
-            if self.embeddings.openai_api_key == "no-key":
-                raise OpenAIKeyNotFound
-
     def calculate_similarity(
         self: t.Self, question: str, generated_questions: list[str]
     ):
@@ -143,7 +137,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
             callbacks=callbacks,
         )
         response = [
-            json_loader.safe_load(r.text, self.llm) for r in result.generations[0]
+            await json_loader.asafe_load(r.text, self.llm)
+            for r in result.generations[0]
         ]
 
         return self._calculate_score(response, row)

src/ragas/metrics/_context_precision.py

Lines changed: 3 additions & 1 deletion
@@ -143,7 +143,9 @@ async def _ascore(
             )
             responses.append(result.generations[0][0].text)
 
-        json_responses = [json_loader.safe_load(item, self.llm) for item in responses]
+        json_responses = [
+            await json_loader.asafe_load(item, self.llm) for item in responses
+        ]
         score = self._calculate_average_precision(json_responses)
         return score
 

src/ragas/metrics/_context_recall.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
         result = await self.llm.agenerate_text(
             self._create_context_recall_prompt(row), callbacks=callbacks
         )
-        response = json_loader.safe_load(result.generations[0][0].text, self.llm)
+        response = await json_loader.asafe_load(result.generations[0][0].text, self.llm)
 
         return self._compute_score(response)
 