Skip to content

Commit 3eebec5

Browse files
authored
Merge pull request #165 from Chainlit/willy/concurrency
fix: thread/step concurrency
2 parents cafe70d + 3ae371a commit 3eebec5

File tree

21 files changed

+206
-181
lines changed

21 files changed

+206
-181
lines changed

literalai/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from literalai.client import AsyncLiteralClient, LiteralClient
22
from literalai.evaluation.dataset import Dataset
3+
from literalai.evaluation.dataset_experiment import (
4+
DatasetExperiment,
5+
DatasetExperimentItem,
6+
)
37
from literalai.evaluation.dataset_item import DatasetItem
4-
from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem
5-
from literalai.prompt_engineering.prompt import Prompt
68
from literalai.my_types import * # noqa
79
from literalai.observability.generation import (
810
BaseGeneration,
@@ -13,6 +15,7 @@
1315
from literalai.observability.message import Message
1416
from literalai.observability.step import Attachment, Score, Step
1517
from literalai.observability.thread import Thread
18+
from literalai.prompt_engineering.prompt import Prompt
1619
from literalai.version import __version__
1720

1821
__all__ = [

literalai/api/asynchronous.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,11 @@
11
import logging
22
import uuid
3+
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, cast
34

5+
import httpx
46
from typing_extensions import deprecated
5-
from typing import (
6-
Any,
7-
Callable,
8-
Dict,
9-
List,
10-
Literal,
11-
Optional,
12-
TypeVar,
13-
Union,
14-
cast,
15-
)
167

178
from literalai.api.base import BaseLiteralAPI, prepare_variables
18-
199
from literalai.api.helpers.attachment_helpers import (
2010
AttachmentUpload,
2111
create_attachment_helper,
@@ -91,6 +81,7 @@
9181
DatasetExperimentItem,
9282
)
9383
from literalai.evaluation.dataset_item import DatasetItem
84+
from literalai.my_types import PaginatedResponse, User
9485
from literalai.observability.filter import (
9586
generations_filters,
9687
generations_order_by,
@@ -102,12 +93,6 @@
10293
threads_order_by,
10394
users_filters,
10495
)
105-
from literalai.observability.thread import Thread
106-
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
107-
108-
import httpx
109-
110-
from literalai.my_types import PaginatedResponse, User
11196
from literalai.observability.generation import (
11297
BaseGeneration,
11398
ChatGeneration,
@@ -123,6 +108,8 @@
123108
StepDict,
124109
StepType,
125110
)
111+
from literalai.observability.thread import Thread
112+
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
126113

127114
logger = logging.getLogger(__name__)
128115

@@ -141,7 +128,11 @@ class AsyncLiteralAPI(BaseLiteralAPI):
141128
R = TypeVar("R")
142129

143130
async def make_gql_call(
144-
self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
131+
self,
132+
description: str,
133+
query: str,
134+
variables: Dict[str, Any],
135+
timeout: Optional[int] = 10,
145136
) -> Dict:
146137
def raise_error(error):
147138
logger.error(f"Failed to {description}: {error}")
@@ -166,8 +157,7 @@ def raise_error(error):
166157
json = response.json()
167158
except ValueError as e:
168159
raise_error(
169-
f"""Failed to parse JSON response: {
170-
e}, content: {response.content!r}"""
160+
f"Failed to parse JSON response: {e}, content: {response.content!r}"
171161
)
172162

173163
if json.get("errors"):
@@ -178,8 +168,7 @@ def raise_error(error):
178168
for value in json["data"].values():
179169
if value and value.get("ok") is False:
180170
raise_error(
181-
f"""Failed to {description}: {
182-
value.get('message')}"""
171+
f"""Failed to {description}: {value.get("message")}"""
183172
)
184173
return json
185174

@@ -203,9 +192,9 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
203192
return response.json()
204193
except ValueError as e:
205194
raise ValueError(
206-
f"""Failed to parse JSON response: {
207-
e}, content: {response.content!r}"""
195+
f"Failed to parse JSON response: {e}, content: {response.content!r}"
208196
)
197+
209198
async def gql_helper(
210199
self,
211200
query: str,
@@ -235,7 +224,9 @@ async def get_user(
235224
) -> "User":
236225
return await self.gql_helper(*get_user_helper(id, identifier))
237226

238-
async def create_user(self, identifier: str, metadata: Optional[Dict] = None) -> "User":
227+
async def create_user(
228+
self, identifier: str, metadata: Optional[Dict] = None
229+
) -> "User":
239230
return await self.gql_helper(*create_user_helper(identifier, metadata))
240231

241232
async def update_user(
@@ -245,7 +236,7 @@ async def update_user(
245236

246237
async def delete_user(self, id: str) -> Dict:
247238
return await self.gql_helper(*delete_user_helper(id))
248-
239+
249240
async def get_or_create_user(
250241
self, identifier: str, metadata: Optional[Dict] = None
251242
) -> "User":
@@ -273,7 +264,7 @@ async def get_threads(
273264
first, after, before, filters, order_by, step_types_to_keep
274265
)
275266
)
276-
267+
277268
async def list_threads(
278269
self,
279270
first: Optional[int] = None,
@@ -491,7 +482,7 @@ async def create_attachment(
491482
thread_id = active_thread.id
492483

493484
if not step_id:
494-
if active_steps := active_steps_var.get([]):
485+
if active_steps := active_steps_var.get():
495486
step_id = active_steps[-1].id
496487
else:
497488
raise Exception("No step_id provided and no active step found.")
@@ -532,7 +523,9 @@ async def create_attachment(
532523
response = await self.make_gql_call(description, query, variables)
533524
return process_response(response)
534525

535-
async def update_attachment(self, id: str, update_params: AttachmentUpload) -> "Attachment":
526+
async def update_attachment(
527+
self, id: str, update_params: AttachmentUpload
528+
) -> "Attachment":
536529
return await self.gql_helper(*update_attachment_helper(id, update_params))
537530

538531
async def get_attachment(self, id: str) -> Optional["Attachment"]:
@@ -545,7 +538,6 @@ async def delete_attachment(self, id: str) -> Dict:
545538
# Step APIs #
546539
##################################################################################
547540

548-
549541
async def create_step(
550542
self,
551543
thread_id: Optional[str] = None,
@@ -646,7 +638,7 @@ async def get_generations(
646638
return await self.gql_helper(
647639
*get_generations_helper(first, after, before, filters, order_by)
648640
)
649-
641+
650642
async def create_generation(
651643
self, generation: Union["ChatGeneration", "CompletionGeneration"]
652644
) -> Union["ChatGeneration", "CompletionGeneration"]:
@@ -667,8 +659,10 @@ async def create_dataset(
667659
return await self.gql_helper(
668660
*create_dataset_helper(sync_api, name, description, metadata, type)
669661
)
670-
671-
async def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None) -> "Dataset":
662+
663+
async def get_dataset(
664+
self, id: Optional[str] = None, name: Optional[str] = None
665+
) -> "Dataset":
672666
sync_api = LiteralAPI(self.api_key, self.url)
673667
subpath, _, variables, process_response = get_dataset_helper(
674668
sync_api, id=id, name=name
@@ -738,7 +732,7 @@ async def create_experiment_item(
738732
result.scores = await self.create_scores(experiment_item.scores)
739733

740734
return result
741-
735+
742736
##################################################################################
743737
# DatasetItem APIs #
744738
##################################################################################
@@ -753,7 +747,7 @@ async def create_dataset_item(
753747
return await self.gql_helper(
754748
*create_dataset_item_helper(dataset_id, input, expected_output, metadata)
755749
)
756-
750+
757751
async def get_dataset_item(self, id: str) -> "DatasetItem":
758752
return await self.gql_helper(*get_dataset_item_helper(id))
759753

@@ -784,7 +778,9 @@ async def get_or_create_prompt_lineage(
784778
return await self.gql_helper(*create_prompt_lineage_helper(name, description))
785779

786780
@deprecated('Please use "get_or_create_prompt_lineage" instead.')
787-
async def create_prompt_lineage(self, name: str, description: Optional[str] = None) -> Dict:
781+
async def create_prompt_lineage(
782+
self, name: str, description: Optional[str] = None
783+
) -> Dict:
788784
return await self.get_or_create_prompt_lineage(name, description)
789785

790786
async def get_or_create_prompt(
@@ -838,7 +834,14 @@ async def get_prompt(
838834
raise ValueError("At least the `id` or the `name` must be provided.")
839835

840836
sync_api = LiteralAPI(self.api_key, self.url)
841-
get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
837+
(
838+
get_prompt_query,
839+
description,
840+
variables,
841+
process_response,
842+
timeout,
843+
cached_prompt,
844+
) = get_prompt_helper(
842845
api=sync_api, id=id, name=name, version=version, cache=self.cache
843846
)
844847

literalai/api/base.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,16 @@
11
import os
2-
32
from abc import ABC, abstractmethod
4-
from typing import (
5-
Any,
6-
Dict,
7-
List,
8-
Optional,
9-
Union,
10-
)
3+
from typing import Any, Dict, List, Optional, Union
114

125
from typing_extensions import deprecated
136

14-
from literalai.my_types import Environment
15-
7+
from literalai.api.helpers.attachment_helpers import AttachmentUpload
8+
from literalai.api.helpers.prompt_helpers import PromptRollout
9+
from literalai.api.helpers.score_helpers import ScoreUpdate
1610
from literalai.cache.shared_cache import SharedCache
1711
from literalai.evaluation.dataset import DatasetType
18-
from literalai.evaluation.dataset_experiment import (
19-
DatasetExperimentItem,
20-
)
21-
from literalai.api.helpers.attachment_helpers import (
22-
AttachmentUpload)
23-
from literalai.api.helpers.score_helpers import (
24-
ScoreUpdate,
25-
)
26-
12+
from literalai.evaluation.dataset_experiment import DatasetExperimentItem
13+
from literalai.my_types import Environment
2714
from literalai.observability.filter import (
2815
generations_filters,
2916
generations_order_by,
@@ -35,24 +22,14 @@
3522
threads_order_by,
3623
users_filters,
3724
)
38-
from literalai.prompt_engineering.prompt import ProviderSettings
39-
40-
41-
from literalai.api.helpers.prompt_helpers import (
42-
PromptRollout)
43-
4425
from literalai.observability.generation import (
4526
ChatGeneration,
4627
CompletionGeneration,
4728
GenerationMessage,
4829
)
49-
from literalai.observability.step import (
50-
ScoreDict,
51-
ScoreType,
52-
Step,
53-
StepDict,
54-
StepType,
55-
)
30+
from literalai.observability.step import ScoreDict, ScoreType, Step, StepDict, StepType
31+
from literalai.prompt_engineering.prompt import ProviderSettings
32+
5633

5734
def prepare_variables(variables: Dict[str, Any]) -> Dict[str, Any]:
5835
"""
@@ -72,6 +49,7 @@ def handle_bytes(item):
7249

7350
return handle_bytes(variables)
7451

52+
7553
class BaseLiteralAPI(ABC):
7654
def __init__(
7755
self,
@@ -676,7 +654,7 @@ def delete_step(
676654
@abstractmethod
677655
def send_steps(self, steps: List[Union[StepDict, "Step"]]):
678656
"""
679-
Sends a list of steps to process.
657+
Sends a list of steps to process.
680658
Step ingestion happens asynchronously if you configured a cache. See [Cache Configuration](https://docs.literalai.com/self-hosting/deployment#4-cache-configuration-optional).
681659
682660
Args:
@@ -773,9 +751,7 @@ def create_dataset(
773751
pass
774752

775753
@abstractmethod
776-
def get_dataset(
777-
self, id: Optional[str] = None, name: Optional[str] = None
778-
):
754+
def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None):
779755
"""
780756
Retrieves a dataset by its ID or name.
781757
@@ -846,9 +822,7 @@ def create_experiment(
846822
pass
847823

848824
@abstractmethod
849-
def create_experiment_item(
850-
self, experiment_item: DatasetExperimentItem
851-
):
825+
def create_experiment_item(self, experiment_item: DatasetExperimentItem):
852826
"""
853827
Creates an experiment item within an existing experiment.
854828
@@ -1065,9 +1039,7 @@ def get_prompt_ab_testing(self, name: str):
10651039
pass
10661040

10671041
@abstractmethod
1068-
def update_prompt_ab_testing(
1069-
self, name: str, rollouts: List[PromptRollout]
1070-
):
1042+
def update_prompt_ab_testing(self, name: str, rollouts: List[PromptRollout]):
10711043
"""
10721044
Update the A/B testing configuration for a prompt lineage.
10731045

literalai/api/helpers/generation_helpers.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import Any, Dict, Optional, Union
22

3+
from literalai.api.helpers import gql
4+
from literalai.my_types import PaginatedResponse
35
from literalai.observability.filter import generations_filters, generations_order_by
4-
from literalai.my_types import (
5-
PaginatedResponse,
6+
from literalai.observability.generation import (
7+
BaseGeneration,
8+
ChatGeneration,
9+
CompletionGeneration,
610
)
7-
from literalai.observability.generation import BaseGeneration, CompletionGeneration, ChatGeneration
8-
9-
from literalai.api.helpers import gql
1011

1112

1213
def get_generations_helper(

0 commit comments

Comments
 (0)