
Commit fe3ed68

Fix codestyle errors
1 parent 2e4d7c9 commit fe3ed68

6 files changed: +31 -19 lines changed


src/ragas/embeddings/base.py

+3 -1
@@ -44,7 +44,9 @@ def set_run_config(self, run_config: RunConfig):

 class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
     def __init__(
-        self, embeddings: Embeddings, run_config: t.Optional[RunConfig] = None
+        self,
+        embeddings: Embeddings,
+        run_config: t.Optional[RunConfig] = None
     ):
         self.embeddings = embeddings
         if run_config is None:
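The constructor's call shape is unchanged by this reflow; only the line wrapping differs. A minimal construction sketch under stated assumptions: the import paths and LangChain's FakeEmbeddings test helper are not part of this commit and are used purely for illustration.

# Hedged sketch: import paths and FakeEmbeddings are assumptions, not shown in this diff.
from langchain_community.embeddings import FakeEmbeddings

from ragas.embeddings.base import LangchainEmbeddingsWrapper
from ragas.run_config import RunConfig

wrapper = LangchainEmbeddingsWrapper(
    embeddings=FakeEmbeddings(size=768),  # any LangChain Embeddings implementation
    run_config=RunConfig(),               # optional; a default RunConfig is created when omitted
)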

src/ragas/executor.py

+8 -9
@@ -25,19 +25,18 @@ def runner_exception_hook(args: threading.ExceptHookArgs):
 # threading.excepthook = runner_exception_hook

 def as_completed(loop, coros, max_workers):
+    loop_arg_dict = {"loop": loop} if sys.version_info[:2] < (3, 10) else {}
     if max_workers == -1:
-        return asyncio.as_completed(coros, loop=loop)
+        return asyncio.as_completed(coros, **loop_arg_dict)

     # loop argument is removed since Python 3.10
-    semaphore = asyncio.Semaphore(
-        max_workers,
-        **({"loop": loop} if sys.version_info[:2] < (3, 10) else {})
-    )
-    async def sem_coro(coro):
+    semaphore = asyncio.Semaphore(max_workers, **loop_arg_dict)
+    async def sema_coro(coro):
         async with semaphore:
             return await coro

-    return asyncio.as_completed([sem_coro(c) for c in coros], loop=loop)
+    sema_coros = [sema_coro(c) for c in coros]
+    return asyncio.as_completed(sema_coros, **loop_arg_dict)

 class Runner(threading.Thread):
     def __init__(

@@ -46,7 +45,7 @@ def __init__(
         desc: str,
         keep_progress_bar: bool = True,
         raise_exceptions: bool = True,
-        run_config: RunConfig = None,
+        run_config: t.Optional[RunConfig] = None
     ):
         super().__init__()
         self.jobs = jobs

@@ -103,7 +102,7 @@ class Executor:
     keep_progress_bar: bool = True
     jobs: t.List[t.Any] = field(default_factory=list, repr=False)
     raise_exceptions: bool = False
-    run_config: RunConfig = None
+    run_config: t.Optional[RunConfig] = field(default_factory=RunConfig, repr=False)

     def wrap_callable_with_index(self, callable: t.Callable, counter):
         async def wrapped_callable_async(*args, **kwargs):
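The idea behind the first hunk is to compute the version check once, store it as loop_arg_dict, and splat it into every asyncio call, so the loop argument is only forwarded on interpreters that still accept it (it was removed from asyncio.as_completed and asyncio.Semaphore in Python 3.10). A self-contained sketch of the same pattern; gated_as_completed and work are illustrative names, not part of ragas.

import asyncio
import sys


def gated_as_completed(loop, coros):
    # Pass loop= only on Python < 3.10, where the argument still exists.
    loop_arg_dict = {"loop": loop} if sys.version_info[:2] < (3, 10) else {}
    return asyncio.as_completed(coros, **loop_arg_dict)


async def work(i):
    await asyncio.sleep(0)
    return i


async def main():
    loop = asyncio.get_running_loop()
    # as_completed yields awaitables in completion order.
    results = [await fut for fut in gated_as_completed(loop, [work(i) for i in range(3)])]
    print(sorted(results))  # [0, 1, 2]


asyncio.run(main())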

src/ragas/llms/base.py

+5 -2
@@ -119,7 +119,9 @@ class LangchainLLMWrapper(BaseRagasLLM):
     """

     def __init__(
-        self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
+        self,
+        langchain_llm: BaseLanguageModel,
+        run_config: t.Optional[RunConfig] = None
     ):
         self.langchain_llm = langchain_llm
         if run_config is None:

@@ -204,7 +206,8 @@ def set_run_config(self, run_config: RunConfig):


 def llm_factory(
-    model: str = "gpt-3.5-turbo-16k", run_config: t.Optional[RunConfig] = None
+    model: str = "gpt-3.5-turbo-16k",
+    run_config: t.Optional[RunConfig] = None
 ) -> BaseRagasLLM:
     timeout = None
     if run_config is not None:
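A usage sketch of the reflowed llm_factory signature. The import paths and the timeout field on RunConfig are assumptions rather than facts established by this hunk, and the factory wraps an OpenAI chat model, so actually calling it needs valid OpenAI credentials.

# Hedged sketch: only the llm_factory signature comes from the diff above;
# import paths and the RunConfig(timeout=...) field are assumed.
from ragas.llms.base import llm_factory
from ragas.run_config import RunConfig

ragas_llm = llm_factory(
    model="gpt-3.5-turbo-16k",
    run_config=RunConfig(timeout=60),  # field name assumed; the hunk only shows run_config being inspected
)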

src/ragas/testset/docstore.py

+1 -1
@@ -190,7 +190,7 @@ class InMemoryDocumentStore(DocumentStore):
     nodes: t.List[Node] = field(default_factory=list)
     node_embeddings_list: t.List[Embedding] = field(default_factory=list)
    node_map: t.Dict[str, Node] = field(default_factory=dict)
-    run_config: RunConfig = None
+    run_config: t.Optional[RunConfig]

     def _embed_items(self, items: t.Union[t.Sequence[Document], t.Sequence[Node]]):
         ...
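The recurring fix in this commit is replacing annotations like run_config: RunConfig = None with an explicit t.Optional[RunConfig]: a None default does not make the annotation Optional on its own, and recent mypy versions reject the implicit form by default. A minimal sketch using a hypothetical Config stand-in, not the real RunConfig.

# Hedged sketch: Config is a made-up stand-in for RunConfig.
import typing as t
from dataclasses import dataclass


@dataclass
class Config:
    timeout: int = 60


def run(config: t.Optional[Config] = None) -> int:
    # `config: Config = None` would rely on implicit Optional,
    # which recent mypy versions flag by default.
    if config is None:
        config = Config()
    return config.timeout


print(run())                   # 60
print(run(Config(timeout=5)))  # 5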

src/ragas/testset/evolutions.py

+10 -2
@@ -84,7 +84,11 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
         new_node.embedding = np.average(node_embeddings, axis=0)
         return new_node

-    def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
+    def init(
+        self,
+        is_async: bool = True,
+        run_config: t.Optional[RunConfig] = None
+    ):
         self.is_async = is_async
         if run_config is None:
             run_config = RunConfig()

@@ -323,7 +327,11 @@ class ComplexEvolution(Evolution):
         default_factory=lambda: compress_question_prompt
     )

-    def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
+    def init(
+        self,
+        is_async: bool = True,
+        run_config: t.Optional[RunConfig] = None
+    ):
         if run_config is None:
             run_config = RunConfig()
         super().init(is_async=is_async, run_config=run_config)

src/ragas/testset/generator.py

+4 -4
@@ -78,7 +78,7 @@ def with_openai(
         embeddings: str = "text-embedding-ada-002",
         docstore: t.Optional[DocumentStore] = None,
         chunk_size: int = 512,
-        run_config: RunConfig = None,
+        run_config: t.Optional[RunConfig] = None,
     ) -> "TestsetGenerator":
         generator_llm_model = LangchainLLMWrapper(ChatOpenAI(model=generator_llm))
         critic_llm_model = LangchainLLMWrapper(ChatOpenAI(model=critic_llm))

@@ -120,7 +120,7 @@ def generate_with_llamaindex_docs(
         with_debugging_logs=False,
         is_async: bool = True,
         raise_exceptions: bool = True,
-        run_config: t.Optional[RunConfig] = None,
+        run_config: t.Optional[RunConfig] = None
     ):
         # chunk documents and add to docstore
         self.docstore.add_documents(

@@ -146,7 +146,7 @@ def generate_with_langchain_docs(
         with_debugging_logs=False,
         is_async: bool = True,
         raise_exceptions: bool = True,
-        run_config: t.Optional[RunConfig] = None,
+        run_config: t.Optional[RunConfig] = None
     ):
         # chunk documents and add to docstore
         self.docstore.add_documents(

@@ -184,7 +184,7 @@ def generate(
         with_debugging_logs=False,
         is_async: bool = True,
         raise_exceptions: bool = True,
-        run_config: t.Optional[RunConfig] = None,
+        run_config: t.Optional[RunConfig] = None
     ):
         # validate distributions
         if not check_if_sum_is_close(list(distributions.values()), 1.0, 3):
