Skip to content

Commit 2e4d7c9

Browse files
committed
Migrated arg-based max_concurrency to RunConfig.
Also renamed max_concurrency to max_workers to be consistent with convention.
1 parent a09d6c2 commit 2e4d7c9

File tree

5 files changed

+31
-62
lines changed

5 files changed

+31
-62
lines changed

src/ragas/evaluation.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper, embedding_factory
1414
from ragas.llms import llm_factory
1515
from ragas.exceptions import ExceptionInRunner
16-
from ragas.executor import Executor, DEFAULT_MAX_CONCURRENCY
16+
from ragas.executor import Executor
1717
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
1818
from ragas.metrics._answer_correctness import AnswerCorrectness
1919
from ragas.metrics.base import Metric, MetricWithEmbeddings, MetricWithLLM
@@ -40,7 +40,6 @@ def evaluate(
4040
embeddings: t.Optional[BaseRagasEmbeddings] = None,
4141
callbacks: Callbacks = [],
4242
is_async: bool = False,
43-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY,
4443
run_config: t.Optional[RunConfig] = None,
4544
raise_exceptions: bool = True,
4645
column_map: t.Dict[str, str] = {},
@@ -73,9 +72,6 @@ def evaluate(
7372
evaluation is run by calling the `metric.ascore` method. In case the llm or
7473
embeddings do not support async, the evaluation can be run in sync mode
7574
with `is_async=False`. Default is False.
76-
max_concurrency: int, optional
77-
The number of workers to use for the evaluation. This is used by the
78-
`ThreadPoolExecutor` to run the evaluation in sync mode.
7975
run_config: RunConfig, optional
8076
Configuration for runtime settings like timeout and retries. If not provided,
8177
default values are used.
@@ -124,8 +120,7 @@ def evaluate(
124120
raise ValueError("Provide dataset!")
125121

126122
# default run_config
127-
if run_config is None:
128-
run_config = RunConfig()
123+
run_config = run_config or RunConfig()
129124
# default metrics
130125
if metrics is None:
131126
from ragas.metrics import (
@@ -180,7 +175,7 @@ def evaluate(
180175
desc="Evaluating",
181176
keep_progress_bar=True,
182177
raise_exceptions=raise_exceptions,
183-
max_concurrency=max_concurrency,
178+
run_config=run_config,
184179
)
185180
# new evaluation chain
186181
row_run_managers = []

src/ragas/executor.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from tqdm.auto import tqdm
1212

1313
from ragas.exceptions import MaxRetriesExceeded
14+
from ragas.run_config import RunConfig
1415

1516
logger = logging.getLogger(__name__)
1617

1718

18-
DEFAULT_MAX_CONCURRENCY = 16
19-
2019
def runner_exception_hook(args: threading.ExceptHookArgs):
2120
print(args)
2221
raise args.exc_type
@@ -25,13 +24,13 @@ def runner_exception_hook(args: threading.ExceptHookArgs):
2524
# set a custom exception hook
2625
# threading.excepthook = runner_exception_hook
2726

28-
def as_completed(loop, coros, max_concurrency):
29-
if max_concurrency == -1:
27+
def as_completed(loop, coros, max_workers):
28+
if max_workers == -1:
3029
return asyncio.as_completed(coros, loop=loop)
3130

3231
# loop argument is removed since Python 3.10
3332
semaphore = asyncio.Semaphore(
34-
max_concurrency,
33+
max_workers,
3534
**({"loop": loop} if sys.version_info[:2] < (3, 10) else {})
3635
)
3736
async def sem_coro(coro):
@@ -47,25 +46,25 @@ def __init__(
4746
desc: str,
4847
keep_progress_bar: bool = True,
4948
raise_exceptions: bool = True,
50-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY,
49+
run_config: RunConfig = None,
5150
):
5251
super().__init__()
5352
self.jobs = jobs
5453
self.desc = desc
5554
self.keep_progress_bar = keep_progress_bar
5655
self.raise_exceptions = raise_exceptions
57-
self.max_concurrency = max_concurrency
56+
self.run_config = run_config or RunConfig()
5857

5958
# create task
6059
self.loop = asyncio.new_event_loop()
6160
self.futures = as_completed(
62-
self.loop,
63-
[coro for coro, _ in self.jobs],
64-
self.max_concurrency)
61+
loop=self.loop,
62+
coros=[coro for coro, _ in self.jobs],
63+
max_workers=self.run_config.max_workers
64+
)
6565

6666
async def _aresults(self) -> t.List[t.Any]:
6767
results = []
68-
6968
for future in tqdm(
7069
self.futures,
7170
desc=self.desc,
@@ -95,7 +94,6 @@ def run(self):
9594
results = self.loop.run_until_complete(self._aresults())
9695
finally:
9796
self.results = results
98-
[f.cancel() for f in self.futures]
9997
self.loop.stop()
10098

10199

@@ -105,7 +103,7 @@ class Executor:
105103
keep_progress_bar: bool = True
106104
jobs: t.List[t.Any] = field(default_factory=list, repr=False)
107105
raise_exceptions: bool = False
108-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY
106+
run_config: RunConfig = None
109107

110108
def wrap_callable_with_index(self, callable: t.Callable, counter):
111109
async def wrapped_callable_async(*args, **kwargs):
@@ -125,7 +123,7 @@ def results(self) -> t.List[t.Any]:
125123
desc=self.desc,
126124
keep_progress_bar=self.keep_progress_bar,
127125
raise_exceptions=self.raise_exceptions,
128-
max_concurrency=self.max_concurrency,
126+
run_config=self.run_config,
129127
)
130128
executor_job.start()
131129
try:

src/ragas/run_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class RunConfig:
2020
timeout: int = 60
2121
max_retries: int = 10
2222
max_wait: int = 60
23+
max_workers: int = 16
2324
exception_types: t.Union[
2425
t.Type[BaseException],
2526
t.Tuple[t.Type[BaseException], ...],

src/ragas/testset/docstore.py

+9-29
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from ragas.embeddings.base import BaseRagasEmbeddings
1818
from ragas.exceptions import ExceptionInRunner
19-
from ragas.executor import Executor, DEFAULT_MAX_CONCURRENCY
19+
from ragas.executor import Executor
2020
from ragas.run_config import RunConfig
2121
from ragas.testset.utils import rng
2222

@@ -83,27 +83,16 @@ class Direction(str, Enum):
8383
UP = "up"
8484
DOWN = "down"
8585

86-
8786
class DocumentStore(ABC):
8887
def __init__(self):
8988
self.documents = {}
9089

9190
@abstractmethod
92-
def add_documents(
93-
self,
94-
docs: t.Sequence[Document],
95-
show_progress=True,
96-
max_concurrency: t.Optional[int]=DEFAULT_MAX_CONCURRENCY
97-
):
91+
def add_documents(self, docs: t.Sequence[Document], show_progress=True):
9892
...
9993

10094
@abstractmethod
101-
def add_nodes(
102-
self,
103-
nodes: t.Sequence[Node],
104-
show_progress=True,
105-
max_concurrency: t.Optional[int]=DEFAULT_MAX_CONCURRENCY
106-
):
95+
def add_nodes(self, nodes: t.Sequence[Node], show_progress=True):
10796
...
10897

10998
@abstractmethod
@@ -201,16 +190,12 @@ class InMemoryDocumentStore(DocumentStore):
201190
nodes: t.List[Node] = field(default_factory=list)
202191
node_embeddings_list: t.List[Embedding] = field(default_factory=list)
203192
node_map: t.Dict[str, Node] = field(default_factory=dict)
193+
run_config: RunConfig = None
204194

205195
def _embed_items(self, items: t.Union[t.Sequence[Document], t.Sequence[Node]]):
206196
...
207197

208-
def add_documents(
209-
self,
210-
docs: t.Sequence[Document],
211-
show_progress=True,
212-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY
213-
):
198+
def add_documents(self, docs: t.Sequence[Document], show_progress=True):
214199
"""
215200
Add documents in batch mode.
216201
"""
@@ -222,15 +207,9 @@ def add_documents(
222207
for d in self.splitter.transform_documents(docs)
223208
]
224209

225-
self.add_nodes(nodes, show_progress=show_progress, max_concurrency=max_concurrency)
210+
self.add_nodes(nodes, show_progress=show_progress)
226211

227-
def add_nodes(
228-
self,
229-
nodes: t.Sequence[Node],
230-
show_progress=True,
231-
desc: str = "embedding nodes",
232-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY,
233-
):
212+
def add_nodes(self, nodes: t.Sequence[Node], show_progress=True):
234213
assert self.embeddings is not None, "Embeddings must be set"
235214
assert self.extractor is not None, "Extractor must be set"
236215

@@ -243,7 +222,7 @@ def add_nodes(
243222
desc="embedding nodes",
244223
keep_progress_bar=False,
245224
raise_exceptions=True,
246-
max_concurrency=max_concurrency,
225+
run_config=self.run_config,
247226
)
248227
result_idx = 0
249228
for i, n in enumerate(nodes):
@@ -341,3 +320,4 @@ def get_adjacent(
341320
def set_run_config(self, run_config: RunConfig):
342321
if self.embeddings:
343322
self.embeddings.set_run_config(run_config)
323+
self.run_config = run_config

src/ragas/testset/generator.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ragas._analytics import TesetGenerationEvent, track
1414
from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
1515
from ragas.exceptions import ExceptionInRunner
16-
from ragas.executor import Executor, DEFAULT_MAX_CONCURRENCY
16+
from ragas.executor import Executor
1717
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
1818
from ragas.run_config import RunConfig
1919
from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
@@ -78,6 +78,7 @@ def with_openai(
7878
embeddings: str = "text-embedding-ada-002",
7979
docstore: t.Optional[DocumentStore] = None,
8080
chunk_size: int = 512,
81+
run_config: RunConfig = None,
8182
) -> "TestsetGenerator":
8283
generator_llm_model = LangchainLLMWrapper(ChatOpenAI(model=generator_llm))
8384
critic_llm_model = LangchainLLMWrapper(ChatOpenAI(model=critic_llm))
@@ -93,6 +94,7 @@ def with_openai(
9394
splitter=splitter,
9495
embeddings=embeddings_model,
9596
extractor=keyphrase_extractor,
97+
run_config=run_config,
9698
)
9799
return cls(
98100
generator_llm=generator_llm_model,
@@ -119,12 +121,10 @@ def generate_with_llamaindex_docs(
119121
is_async: bool = True,
120122
raise_exceptions: bool = True,
121123
run_config: t.Optional[RunConfig] = None,
122-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY,
123124
):
124125
# chunk documents and add to docstore
125126
self.docstore.add_documents(
126-
[Document.from_llamaindex_document(doc) for doc in documents],
127-
max_concurrency=max_concurrency,
127+
[Document.from_llamaindex_document(doc) for doc in documents]
128128
)
129129

130130
return self.generate(
@@ -134,7 +134,6 @@ def generate_with_llamaindex_docs(
134134
is_async=is_async,
135135
run_config=run_config,
136136
raise_exceptions=raise_exceptions,
137-
max_concurrency=max_concurrency,
138137
)
139138

140139
# if you add any arguments to this function, make sure to add them to
@@ -148,12 +147,10 @@ def generate_with_langchain_docs(
148147
is_async: bool = True,
149148
raise_exceptions: bool = True,
150149
run_config: t.Optional[RunConfig] = None,
151-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY,
152150
):
153151
# chunk documents and add to docstore
154152
self.docstore.add_documents(
155-
[Document.from_langchain_document(doc) for doc in documents],
156-
max_concurrency=max_concurrency,
153+
[Document.from_langchain_document(doc) for doc in documents]
157154
)
158155

159156
return self.generate(
@@ -163,7 +160,6 @@ def generate_with_langchain_docs(
163160
is_async=is_async,
164161
raise_exceptions=raise_exceptions,
165162
run_config=run_config,
166-
max_concurrency=max_concurrency,
167163
)
168164

169165
def init_evolution(self, evolution: Evolution) -> None:
@@ -189,7 +185,6 @@ def generate(
189185
is_async: bool = True,
190186
raise_exceptions: bool = True,
191187
run_config: t.Optional[RunConfig] = None,
192-
max_concurrency: t.Optional[int] = DEFAULT_MAX_CONCURRENCY,
193188
):
194189
# validate distributions
195190
if not check_if_sum_is_close(list(distributions.values()), 1.0, 3):
@@ -220,7 +215,7 @@ def generate(
220215
desc="Generating",
221216
keep_progress_bar=True,
222217
raise_exceptions=raise_exceptions,
223-
max_concurrency=max_concurrency,
218+
run_config=run_config,
224219
)
225220

226221
current_nodes = [

0 commit comments

Comments
 (0)