Skip to content

Add support for optional max concurrency #643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/ragas/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def set_run_config(self, run_config: RunConfig):

class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
def __init__(
self, embeddings: Embeddings, run_config: t.Optional[RunConfig] = None
self,
embeddings: Embeddings,
run_config: t.Optional[RunConfig] = None
):
self.embeddings = embeddings
if run_config is None:
Expand Down
8 changes: 2 additions & 6 deletions src/ragas/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def evaluate(
embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
callbacks: Callbacks = [],
is_async: bool = False,
max_workers: t.Optional[int] = None,
run_config: t.Optional[RunConfig] = None,
raise_exceptions: bool = True,
column_map: t.Dict[str, str] = {},
Expand Down Expand Up @@ -77,9 +76,6 @@ def evaluate(
evaluation is run by calling the `metric.ascore` method. In case the llm or
embeddings does not support async then the evaluation can be run in sync mode
with `is_async=False`. Default is False.
max_workers: int, optional
The number of workers to use for the evaluation. This is used by the
`ThreadpoolExecutor` to run the evaluation in sync mode.
run_config: RunConfig, optional
Configuration for runtime settings like timeout and retries. If not provided,
default values are used.
Expand Down Expand Up @@ -128,8 +124,7 @@ def evaluate(
raise ValueError("Provide dataset!")

# default run_config
if run_config is None:
run_config = RunConfig()
run_config = run_config or RunConfig()
# default metrics
if metrics is None:
from ragas.metrics import (
Expand Down Expand Up @@ -184,6 +179,7 @@ def evaluate(
desc="Evaluating",
keep_progress_bar=True,
raise_exceptions=raise_exceptions,
run_config=run_config,
)
# new evaluation chain
row_run_managers = []
Expand Down
33 changes: 26 additions & 7 deletions src/ragas/executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import sys

import asyncio
import logging
Expand All @@ -10,6 +11,7 @@
from tqdm.auto import tqdm

from ragas.exceptions import MaxRetriesExceeded
from ragas.run_config import RunConfig

logger = logging.getLogger(__name__)

Expand All @@ -22,6 +24,19 @@ def runner_exception_hook(args: threading.ExceptHookArgs):
# set a custom exception hook
# threading.excepthook = runner_exception_hook

def as_completed(loop, coros, max_workers):
    """Wrap ``asyncio.as_completed`` with an optional concurrency limit.

    Parameters
    ----------
    loop:
        Event loop the coroutines are meant to run on. It is only forwarded
        to asyncio on Python < 3.10; later versions removed the ``loop``
        argument from ``asyncio.as_completed`` and ``asyncio.Semaphore``.
    coros:
        Iterable of coroutines (awaitables) to run.
    max_workers: int
        Maximum number of coroutines allowed to run concurrently. ``-1``
        disables the limit.

    Returns
    -------
    The iterator produced by ``asyncio.as_completed``; each item it yields
    is an awaitable that resolves to the result of one of ``coros``.

    Raises
    ------
    ValueError
        If ``max_workers`` is zero or negative (other than ``-1``). A
        semaphore of size zero could never be acquired, so every job would
        hang forever instead of failing loudly.
    """
    # The explicit ``loop`` argument was removed from asyncio in Python 3.10.
    loop_kwargs = {"loop": loop} if sys.version_info[:2] < (3, 10) else {}

    if max_workers == -1:
        # Unlimited concurrency: hand the coroutines straight to asyncio.
        return asyncio.as_completed(coros, **loop_kwargs)

    if max_workers < 1:
        raise ValueError(
            "max_workers must be a positive integer or -1 (no limit), "
            f"got {max_workers!r}"
        )

    semaphore = asyncio.Semaphore(max_workers, **loop_kwargs)

    async def sema_coro(coro):
        # At most ``max_workers`` wrapped coroutines hold the semaphore (and
        # therefore make progress on ``coro``) at any one time.
        async with semaphore:
            return await coro

    return asyncio.as_completed([sema_coro(c) for c in coros], **loop_kwargs)

class Runner(threading.Thread):
def __init__(
Expand All @@ -30,26 +45,29 @@ def __init__(
desc: str,
keep_progress_bar: bool = True,
raise_exceptions: bool = True,
run_config: t.Optional[RunConfig] = None
):
super().__init__()
self.jobs = jobs
self.desc = desc
self.keep_progress_bar = keep_progress_bar
self.raise_exceptions = raise_exceptions
self.futures = []
self.run_config = run_config or RunConfig()

# create task
self.loop = asyncio.new_event_loop()
for job in self.jobs:
coroutine, name = job
self.futures.append(self.loop.create_task(coroutine, name=name))
self.futures = as_completed(
loop=self.loop,
coros=[coro for coro, _ in self.jobs],
max_workers=self.run_config.max_workers
)

async def _aresults(self) -> t.List[t.Any]:
results = []
for future in tqdm(
asyncio.as_completed(self.futures),
self.futures,
desc=self.desc,
total=len(self.futures),
total=len(self.jobs),
# whether you want to keep the progress bar after completion
leave=self.keep_progress_bar,
):
Expand All @@ -75,7 +93,6 @@ def run(self):
results = self.loop.run_until_complete(self._aresults())
finally:
self.results = results
[f.cancel() for f in self.futures]
self.loop.stop()


Expand All @@ -85,6 +102,7 @@ class Executor:
keep_progress_bar: bool = True
jobs: t.List[t.Any] = field(default_factory=list, repr=False)
raise_exceptions: bool = False
run_config: t.Optional[RunConfig] = field(default_factory=RunConfig, repr=False)

def wrap_callable_with_index(self, callable: t.Callable, counter):
async def wrapped_callable_async(*args, **kwargs):
Expand All @@ -104,6 +122,7 @@ def results(self) -> t.List[t.Any]:
desc=self.desc,
keep_progress_bar=self.keep_progress_bar,
raise_exceptions=self.raise_exceptions,
run_config=self.run_config,
)
executor_job.start()
try:
Expand Down
7 changes: 5 additions & 2 deletions src/ragas/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ class LangchainLLMWrapper(BaseRagasLLM):
"""

def __init__(
self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
self,
langchain_llm: BaseLanguageModel,
run_config: t.Optional[RunConfig] = None
):
self.langchain_llm = langchain_llm
if run_config is None:
Expand Down Expand Up @@ -204,7 +206,8 @@ def set_run_config(self, run_config: RunConfig):


def llm_factory(
model: str = "gpt-3.5-turbo-16k", run_config: t.Optional[RunConfig] = None
model: str = "gpt-3.5-turbo-16k",
run_config: t.Optional[RunConfig] = None
) -> BaseRagasLLM:
timeout = None
if run_config is not None:
Expand Down
1 change: 1 addition & 0 deletions src/ragas/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class RunConfig:
timeout: int = 60
max_retries: int = 10
max_wait: int = 60
max_workers: int = 16
exception_types: t.Union[
t.Type[BaseException],
t.Tuple[t.Type[BaseException], ...],
Expand Down
9 changes: 5 additions & 4 deletions src/ragas/testset/docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Direction(str, Enum):
PREV = "prev"
UP = "up"
DOWN = "down"


class Node(Document):
keyphrases: t.List[str] = Field(default_factory=list, repr=False)
Expand Down Expand Up @@ -196,6 +196,7 @@ class InMemoryDocumentStore(DocumentStore):
nodes: t.List[Node] = field(default_factory=list)
node_embeddings_list: t.List[Embedding] = field(default_factory=list)
node_map: t.Dict[str, Node] = field(default_factory=dict)
run_config: t.Optional[RunConfig] = None

def _embed_items(self, items: t.Union[t.Sequence[Document], t.Sequence[Node]]):
...
Expand All @@ -213,9 +214,7 @@ def add_documents(self, docs: t.Sequence[Document], show_progress=True):
]
self.add_nodes(nodes, show_progress=show_progress)

def add_nodes(
self, nodes: t.Sequence[Node], show_progress=True, desc: str = "embedding nodes"
):
def add_nodes(self, nodes: t.Sequence[Node], show_progress=True):
assert self.embeddings is not None, "Embeddings must be set"
assert self.extractor is not None, "Extractor must be set"

Expand All @@ -228,6 +227,7 @@ def add_nodes(
desc="embedding nodes",
keep_progress_bar=False,
raise_exceptions=True,
run_config=self.run_config,
)
result_idx = 0
for i, n in enumerate(nodes):
Expand Down Expand Up @@ -356,3 +356,4 @@ def get_similar(
def set_run_config(self, run_config: RunConfig):
if self.embeddings:
self.embeddings.set_run_config(run_config)
self.run_config = run_config
12 changes: 10 additions & 2 deletions src/ragas/testset/evolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
new_node.embedding = np.average(node_embeddings, axis=0)
return new_node

def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
def init(
self,
is_async: bool = True,
run_config: t.Optional[RunConfig] = None
):
self.is_async = is_async
if run_config is None:
run_config = RunConfig()
Expand Down Expand Up @@ -331,7 +335,11 @@ class ComplexEvolution(Evolution):
default_factory=lambda: compress_question_prompt
)

def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
def init(
self,
is_async: bool = True,
run_config: t.Optional[RunConfig] = None
):
if run_config is None:
run_config = RunConfig()
super().init(is_async=is_async, run_config=run_config)
Expand Down
9 changes: 6 additions & 3 deletions src/ragas/testset/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def with_openai(
critic_llm: str = "gpt-4",
embeddings: str = "text-embedding-ada-002",
docstore: t.Optional[DocumentStore] = None,
run_config: t.Optional[RunConfig] = None,
chunk_size: int = 1024,
) -> "TestsetGenerator":
generator_llm_model = LangchainLLMWrapper(ChatOpenAI(model=generator_llm))
Expand All @@ -93,6 +94,7 @@ def with_openai(
splitter=splitter,
embeddings=embeddings_model,
extractor=keyphrase_extractor,
run_config=run_config,
)
return cls(
generator_llm=generator_llm_model,
Expand All @@ -118,7 +120,7 @@ def generate_with_llamaindex_docs(
with_debugging_logs=False,
is_async: bool = True,
raise_exceptions: bool = True,
run_config: t.Optional[RunConfig] = None,
run_config: t.Optional[RunConfig] = None
):
# chunk documents and add to docstore
self.docstore.add_documents(
Expand All @@ -144,7 +146,7 @@ def generate_with_langchain_docs(
with_debugging_logs=False,
is_async: bool = True,
raise_exceptions: bool = True,
run_config: t.Optional[RunConfig] = None,
run_config: t.Optional[RunConfig] = None
):
# chunk documents and add to docstore
self.docstore.add_documents(
Expand Down Expand Up @@ -182,7 +184,7 @@ def generate(
with_debugging_logs=False,
is_async: bool = True,
raise_exceptions: bool = True,
run_config: t.Optional[RunConfig] = None,
run_config: t.Optional[RunConfig] = None
):
# validate distributions
if not check_if_sum_is_close(list(distributions.values()), 1.0, 3):
Expand Down Expand Up @@ -213,6 +215,7 @@ def generate(
desc="Generating",
keep_progress_bar=True,
raise_exceptions=raise_exceptions,
run_config=run_config,
)

current_nodes = [
Expand Down