Skip to content

Commit 12ad190

Browse files
joy13975 and jjmachan authored
Add support for optional max concurrency (#643)
**Added optional Semaphore-based concurrency control for #642**

As for the default value of `max_concurrency` (exposed in this diff as `RunConfig.max_workers`): I don't know the ratio of API users to local-LLM users, so the proposed default is an opinionated value of `16`.

* I *think* more people currently use the OpenAI API than local LLMs, so the default is not `-1` (no limit).
* `16` seems reasonably fast and, in my experience, does not hit the throughput limit.

**Tests**

Embedding 1k documents finished in under 2 minutes, and the subsequent Testset generation for `test_size=1000` proceeded without getting stuck: <img width="693" alt="image" src="https://github.com/explodinggradients/ragas/assets/6729737/d83fecc8-a815-43ee-a3b0-3395d7a9d244">

After another 30 seconds: <img width="725" alt="image" src="https://github.com/explodinggradients/ragas/assets/6729737/d4ab08ba-5a79-45f6-84b1-e563f107d682">

---------

Co-authored-by: Jithin James <[email protected]>
1 parent 366cb9f commit 12ad190

File tree

8 files changed

+58
-25
lines changed

8 files changed

+58
-25
lines changed

src/ragas/embeddings/base.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,9 @@ def set_run_config(self, run_config: RunConfig):
4444

4545
class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
4646
def __init__(
47-
self, embeddings: Embeddings, run_config: t.Optional[RunConfig] = None
47+
self,
48+
embeddings: Embeddings,
49+
run_config: t.Optional[RunConfig] = None
4850
):
4951
self.embeddings = embeddings
5052
if run_config is None:

src/ragas/evaluation.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ def evaluate(
4444
embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
4545
callbacks: Callbacks = [],
4646
is_async: bool = False,
47-
max_workers: t.Optional[int] = None,
4847
run_config: t.Optional[RunConfig] = None,
4948
raise_exceptions: bool = True,
5049
column_map: t.Dict[str, str] = {},
@@ -77,9 +76,6 @@ def evaluate(
7776
evaluation is run by calling the `metric.ascore` method. In case the llm or
7877
embeddings does not support async then the evaluation can be run in sync mode
7978
with `is_async=False`. Default is False.
80-
max_workers: int, optional
81-
The number of workers to use for the evaluation. This is used by the
82-
`ThreadpoolExecutor` to run the evaluation in sync mode.
8379
run_config: RunConfig, optional
8480
Configuration for runtime settings like timeout and retries. If not provided,
8581
default values are used.
@@ -128,8 +124,7 @@ def evaluate(
128124
raise ValueError("Provide dataset!")
129125

130126
# default run_config
131-
if run_config is None:
132-
run_config = RunConfig()
127+
run_config = run_config or RunConfig()
133128
# default metrics
134129
if metrics is None:
135130
from ragas.metrics import (
@@ -184,6 +179,7 @@ def evaluate(
184179
desc="Evaluating",
185180
keep_progress_bar=True,
186181
raise_exceptions=raise_exceptions,
182+
run_config=run_config,
187183
)
188184
# new evaluation chain
189185
row_run_managers = []

src/ragas/executor.py

Lines changed: 26 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import sys
23

34
import asyncio
45
import logging
@@ -10,6 +11,7 @@
1011
from tqdm.auto import tqdm
1112

1213
from ragas.exceptions import MaxRetriesExceeded
14+
from ragas.run_config import RunConfig
1315

1416
logger = logging.getLogger(__name__)
1517

@@ -22,6 +24,19 @@ def runner_exception_hook(args: threading.ExceptHookArgs):
2224
# set a custom exception hook
2325
# threading.excepthook = runner_exception_hook
2426

27+
def as_completed(loop, coros, max_workers):
28+
loop_arg_dict = {"loop": loop} if sys.version_info[:2] < (3, 10) else {}
29+
if max_workers == -1:
30+
return asyncio.as_completed(coros, **loop_arg_dict)
31+
32+
# loop argument is removed since Python 3.10
33+
semaphore = asyncio.Semaphore(max_workers, **loop_arg_dict)
34+
async def sema_coro(coro):
35+
async with semaphore:
36+
return await coro
37+
38+
sema_coros = [sema_coro(c) for c in coros]
39+
return asyncio.as_completed(sema_coros, **loop_arg_dict)
2540

2641
class Runner(threading.Thread):
2742
def __init__(
@@ -30,26 +45,29 @@ def __init__(
3045
desc: str,
3146
keep_progress_bar: bool = True,
3247
raise_exceptions: bool = True,
48+
run_config: t.Optional[RunConfig] = None
3349
):
3450
super().__init__()
3551
self.jobs = jobs
3652
self.desc = desc
3753
self.keep_progress_bar = keep_progress_bar
3854
self.raise_exceptions = raise_exceptions
39-
self.futures = []
55+
self.run_config = run_config or RunConfig()
4056

4157
# create task
4258
self.loop = asyncio.new_event_loop()
43-
for job in self.jobs:
44-
coroutine, name = job
45-
self.futures.append(self.loop.create_task(coroutine, name=name))
59+
self.futures = as_completed(
60+
loop=self.loop,
61+
coros=[coro for coro, _ in self.jobs],
62+
max_workers=self.run_config.max_workers
63+
)
4664

4765
async def _aresults(self) -> t.List[t.Any]:
4866
results = []
4967
for future in tqdm(
50-
asyncio.as_completed(self.futures),
68+
self.futures,
5169
desc=self.desc,
52-
total=len(self.futures),
70+
total=len(self.jobs),
5371
# whether you want to keep the progress bar after completion
5472
leave=self.keep_progress_bar,
5573
):
@@ -75,7 +93,6 @@ def run(self):
7593
results = self.loop.run_until_complete(self._aresults())
7694
finally:
7795
self.results = results
78-
[f.cancel() for f in self.futures]
7996
self.loop.stop()
8097

8198

@@ -85,6 +102,7 @@ class Executor:
85102
keep_progress_bar: bool = True
86103
jobs: t.List[t.Any] = field(default_factory=list, repr=False)
87104
raise_exceptions: bool = False
105+
run_config: t.Optional[RunConfig] = field(default_factory=RunConfig, repr=False)
88106

89107
def wrap_callable_with_index(self, callable: t.Callable, counter):
90108
async def wrapped_callable_async(*args, **kwargs):
@@ -104,6 +122,7 @@ def results(self) -> t.List[t.Any]:
104122
desc=self.desc,
105123
keep_progress_bar=self.keep_progress_bar,
106124
raise_exceptions=self.raise_exceptions,
125+
run_config=self.run_config,
107126
)
108127
executor_job.start()
109128
try:

src/ragas/llms/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ class LangchainLLMWrapper(BaseRagasLLM):
119119
"""
120120

121121
def __init__(
122-
self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
122+
self,
123+
langchain_llm: BaseLanguageModel,
124+
run_config: t.Optional[RunConfig] = None
123125
):
124126
self.langchain_llm = langchain_llm
125127
if run_config is None:
@@ -204,7 +206,8 @@ def set_run_config(self, run_config: RunConfig):
204206

205207

206208
def llm_factory(
207-
model: str = "gpt-3.5-turbo-16k", run_config: t.Optional[RunConfig] = None
209+
model: str = "gpt-3.5-turbo-16k",
210+
run_config: t.Optional[RunConfig] = None
208211
) -> BaseRagasLLM:
209212
timeout = None
210213
if run_config is not None:

src/ragas/run_config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ class RunConfig:
2020
timeout: int = 60
2121
max_retries: int = 10
2222
max_wait: int = 60
23+
max_workers: int = 16
2324
exception_types: t.Union[
2425
t.Type[BaseException],
2526
t.Tuple[t.Type[BaseException], ...],

src/ragas/testset/docstore.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ class Direction(str, Enum):
7878
PREV = "prev"
7979
UP = "up"
8080
DOWN = "down"
81-
81+
8282

8383
class Node(Document):
8484
keyphrases: t.List[str] = Field(default_factory=list, repr=False)
@@ -196,6 +196,7 @@ class InMemoryDocumentStore(DocumentStore):
196196
nodes: t.List[Node] = field(default_factory=list)
197197
node_embeddings_list: t.List[Embedding] = field(default_factory=list)
198198
node_map: t.Dict[str, Node] = field(default_factory=dict)
199+
run_config: t.Optional[RunConfig] = None
199200

200201
def _embed_items(self, items: t.Union[t.Sequence[Document], t.Sequence[Node]]):
201202
...
@@ -213,9 +214,7 @@ def add_documents(self, docs: t.Sequence[Document], show_progress=True):
213214
]
214215
self.add_nodes(nodes, show_progress=show_progress)
215216

216-
def add_nodes(
217-
self, nodes: t.Sequence[Node], show_progress=True, desc: str = "embedding nodes"
218-
):
217+
def add_nodes(self, nodes: t.Sequence[Node], show_progress=True):
219218
assert self.embeddings is not None, "Embeddings must be set"
220219
assert self.extractor is not None, "Extractor must be set"
221220

@@ -228,6 +227,7 @@ def add_nodes(
228227
desc="embedding nodes",
229228
keep_progress_bar=False,
230229
raise_exceptions=True,
230+
run_config=self.run_config,
231231
)
232232
result_idx = 0
233233
for i, n in enumerate(nodes):
@@ -356,3 +356,4 @@ def get_similar(
356356
def set_run_config(self, run_config: RunConfig):
357357
if self.embeddings:
358358
self.embeddings.set_run_config(run_config)
359+
self.run_config = run_config

src/ragas/testset/evolutions.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,11 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
8888
new_node.embedding = np.average(node_embeddings, axis=0)
8989
return new_node
9090

91-
def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
91+
def init(
92+
self,
93+
is_async: bool = True,
94+
run_config: t.Optional[RunConfig] = None
95+
):
9296
self.is_async = is_async
9397
if run_config is None:
9498
run_config = RunConfig()
@@ -335,7 +339,11 @@ class ComplexEvolution(Evolution):
335339
default_factory=lambda: compress_question_prompt
336340
)
337341

338-
def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
342+
def init(
343+
self,
344+
is_async: bool = True,
345+
run_config: t.Optional[RunConfig] = None
346+
):
339347
if run_config is None:
340348
run_config = RunConfig()
341349
super().init(is_async=is_async, run_config=run_config)

src/ragas/testset/generator.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@ def with_openai(
7777
critic_llm: str = "gpt-4",
7878
embeddings: str = "text-embedding-ada-002",
7979
docstore: t.Optional[DocumentStore] = None,
80+
run_config: t.Optional[RunConfig] = None,
8081
chunk_size: int = 1024,
8182
) -> "TestsetGenerator":
8283
generator_llm_model = LangchainLLMWrapper(ChatOpenAI(model=generator_llm))
@@ -93,6 +94,7 @@ def with_openai(
9394
splitter=splitter,
9495
embeddings=embeddings_model,
9596
extractor=keyphrase_extractor,
97+
run_config=run_config,
9698
)
9799
return cls(
98100
generator_llm=generator_llm_model,
@@ -118,7 +120,7 @@ def generate_with_llamaindex_docs(
118120
with_debugging_logs=False,
119121
is_async: bool = True,
120122
raise_exceptions: bool = True,
121-
run_config: t.Optional[RunConfig] = None,
123+
run_config: t.Optional[RunConfig] = None
122124
):
123125
# chunk documents and add to docstore
124126
self.docstore.add_documents(
@@ -144,7 +146,7 @@ def generate_with_langchain_docs(
144146
with_debugging_logs=False,
145147
is_async: bool = True,
146148
raise_exceptions: bool = True,
147-
run_config: t.Optional[RunConfig] = None,
149+
run_config: t.Optional[RunConfig] = None
148150
):
149151
# chunk documents and add to docstore
150152
self.docstore.add_documents(
@@ -182,7 +184,7 @@ def generate(
182184
with_debugging_logs=False,
183185
is_async: bool = True,
184186
raise_exceptions: bool = True,
185-
run_config: t.Optional[RunConfig] = None,
187+
run_config: t.Optional[RunConfig] = None
186188
):
187189
# validate distributions
188190
if not check_if_sum_is_close(list(distributions.values()), 1.0, 3):
@@ -213,6 +215,7 @@ def generate(
213215
desc="Generating",
214216
keep_progress_bar=True,
215217
raise_exceptions=raise_exceptions,
218+
run_config=run_config,
216219
)
217220

218221
current_nodes = [

0 commit comments

Comments
 (0)