Skip to content

Commit 834566f

Browse files
committed
Add optional max concurrency support to Runner
1 parent c18c7f4 commit 834566f

File tree

4 files changed

+36
-10
lines changed

4 files changed

+36
-10
lines changed

src/ragas/evaluation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper, embedding_factory
1414
from ragas.llms import llm_factory
1515
from ragas.exceptions import ExceptionInRunner
16-
from ragas.executor import Executor
16+
from ragas.executor import Executor, DEFAULT_MAX_CONCURRENCY
1717
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
1818
from ragas.metrics._answer_correctness import AnswerCorrectness
1919
from ragas.metrics.base import Metric, MetricWithEmbeddings, MetricWithLLM
@@ -44,6 +44,7 @@ def evaluate(
4444
run_config: t.Optional[RunConfig] = None,
4545
raise_exceptions: bool = True,
4646
column_map: t.Dict[str, str] = {},
47+
max_concurrency: int = DEFAULT_MAX_CONCURRENCY,
4748
) -> Result:
4849
"""
4950
Run the evaluation on the dataset with different metrics
@@ -180,6 +181,7 @@ def evaluate(
180181
desc="Evaluating",
181182
keep_progress_bar=True,
182183
raise_exceptions=raise_exceptions,
184+
max_concurrency=max_concurrency,
183185
)
184186
# new evaluation chain
185187
row_run_managers = []

src/ragas/executor.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
logger = logging.getLogger(__name__)
1515

1616

17+
DEFAULT_MAX_CONCURRENCY = 16
18+
1719
def runner_exception_hook(args: threading.ExceptHookArgs):
1820
print(args)
1921
raise args.exc_type
@@ -22,6 +24,16 @@ def runner_exception_hook(args: threading.ExceptHookArgs):
2224
# set a custom exception hook
2325
# threading.excepthook = runner_exception_hook
2426

27+
def as_completed(loop, coros, max_concurrency):
    """
    Yield awaitables for ``coros`` in completion order, optionally limiting
    how many of them run at the same time.

    Parameters
    ----------
    loop:
        Event loop the caller intends to run the coroutines on. It is kept in
        the signature for backward compatibility but is no longer forwarded:
        the ``loop=`` argument of ``asyncio.as_completed`` and
        ``asyncio.Semaphore`` was removed in Python 3.10. The awaitables are
        bound to whichever loop is running when iteration starts, so iterate
        the result from a coroutine running on this loop.
    coros:
        Iterable of coroutine objects to schedule.
    max_concurrency:
        Maximum number of coroutines allowed to run at once, or ``-1`` for no
        limit.

    Returns
    -------
    A lazy iterator of awaitables; awaiting each one gives the result of the
    next coroutine to finish.

    Raises
    ------
    ValueError
        If ``max_concurrency`` is neither ``-1`` nor a positive integer. A
        limit of ``0`` would otherwise build a semaphore that can never be
        acquired and hang every job.
    """
    if max_concurrency != -1 and max_concurrency < 1:
        raise ValueError(
            "max_concurrency must be -1 (unlimited) or a positive integer, "
            f"got {max_concurrency!r}"
        )

    pending = list(coros)

    def _iter_completed():
        # Everything below runs on the first ``next()``, not when
        # ``as_completed`` is called. That keeps the semaphore and the
        # scheduled tasks attached to the loop that is actually running,
        # on every Python version (3.13+ resolves the loop eagerly inside
        # ``asyncio.as_completed``; 3.8/3.9 do so inside ``Semaphore()``).
        if max_concurrency == -1:
            awaitables = pending
        else:
            semaphore = asyncio.Semaphore(max_concurrency)

            async def sem_coro(coro):
                # Hold a slot for the whole lifetime of the wrapped coroutine.
                async with semaphore:
                    return await coro

            awaitables = [sem_coro(c) for c in pending]

        yield from asyncio.as_completed(awaitables)

    return _iter_completed()
2537

2638
class Runner(threading.Thread):
2739
def __init__(
@@ -30,26 +42,29 @@ def __init__(
3042
desc: str,
3143
keep_progress_bar: bool = True,
3244
raise_exceptions: bool = True,
45+
max_concurrency: int = DEFAULT_MAX_CONCURRENCY,
3346
):
3447
super().__init__()
3548
self.jobs = jobs
3649
self.desc = desc
3750
self.keep_progress_bar = keep_progress_bar
3851
self.raise_exceptions = raise_exceptions
39-
self.futures = []
52+
self.max_concurrency = max_concurrency
4053

4154
# create task
4255
self.loop = asyncio.new_event_loop()
43-
for job in self.jobs:
44-
coroutine, name = job
45-
self.futures.append(self.loop.create_task(coroutine, name=name))
56+
self.futures = as_completed(
57+
self.loop,
58+
[coro for coro, _ in self.jobs],
59+
self.max_concurrency)
4660

4761
async def _aresults(self) -> t.List[t.Any]:
4862
results = []
63+
4964
for future in tqdm(
50-
asyncio.as_completed(self.futures),
65+
self.futures,
5166
desc=self.desc,
52-
total=len(self.futures),
67+
total=len(self.jobs),
5368
# whether you want to keep the progress bar after completion
5469
leave=self.keep_progress_bar,
5570
):
@@ -85,6 +100,7 @@ class Executor:
85100
keep_progress_bar: bool = True
86101
jobs: t.List[t.Any] = field(default_factory=list, repr=False)
87102
raise_exceptions: bool = False
103+
max_concurrency: int = DEFAULT_MAX_CONCURRENCY
88104

89105
def wrap_callable_with_index(self, callable: t.Callable, counter):
90106
async def wrapped_callable_async(*args, **kwargs):
@@ -104,6 +120,7 @@ def results(self) -> t.List[t.Any]:
104120
desc=self.desc,
105121
keep_progress_bar=self.keep_progress_bar,
106122
raise_exceptions=self.raise_exceptions,
123+
max_concurrency=self.max_concurrency,
107124
)
108125
executor_job.start()
109126
try:

src/ragas/testset/docstore.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from ragas.embeddings.base import BaseRagasEmbeddings
1818
from ragas.exceptions import ExceptionInRunner
19-
from ragas.executor import Executor
19+
from ragas.executor import Executor, DEFAULT_MAX_CONCURRENCY
2020
from ragas.run_config import RunConfig
2121
from ragas.testset.utils import rng
2222

@@ -210,7 +210,11 @@ def add_documents(self, docs: t.Sequence[Document], show_progress=True):
210210
self.add_nodes(nodes, show_progress=show_progress)
211211

212212
def add_nodes(
213-
self, nodes: t.Sequence[Node], show_progress=True, desc: str = "embedding nodes"
213+
self,
214+
nodes: t.Sequence[Node],
215+
show_progress=True,
216+
desc: str = "embedding nodes",
217+
max_concurrency: int = DEFAULT_MAX_CONCURRENCY,
214218
):
215219
assert self.embeddings is not None, "Embeddings must be set"
216220
assert self.extractor is not None, "Extractor must be set"
@@ -224,6 +228,7 @@ def add_nodes(
224228
desc="embedding nodes",
225229
keep_progress_bar=False,
226230
raise_exceptions=True,
231+
max_concurrency=max_concurrency,
227232
)
228233
result_idx = 0
229234
for i, n in enumerate(nodes):

src/ragas/testset/generator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ragas._analytics import TesetGenerationEvent, track
1414
from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
1515
from ragas.exceptions import ExceptionInRunner
16-
from ragas.executor import Executor
16+
from ragas.executor import Executor, DEFAULT_MAX_CONCURRENCY
1717
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
1818
from ragas.run_config import RunConfig
1919
from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
@@ -183,6 +183,7 @@ def generate(
183183
is_async: bool = True,
184184
raise_exceptions: bool = True,
185185
run_config: t.Optional[RunConfig] = None,
186+
max_concurrency: int = DEFAULT_MAX_CONCURRENCY,
186187
):
187188
# validate distributions
188189
if not check_if_sum_is_close(list(distributions.values()), 1.0, 3):
@@ -213,6 +214,7 @@ def generate(
213214
desc="Generating",
214215
keep_progress_bar=True,
215216
raise_exceptions=raise_exceptions,
217+
max_concurrency=max_concurrency,
216218
)
217219

218220
current_nodes = [

0 commit comments

Comments
 (0)