1
1
import logging
2
2
import uuid
3
+ from typing import Any , Callable , Dict , List , Literal , Optional , TypeVar , Union , cast
3
4
5
+ import httpx
4
6
from typing_extensions import deprecated
5
- from typing import (
6
- Any ,
7
- Callable ,
8
- Dict ,
9
- List ,
10
- Literal ,
11
- Optional ,
12
- TypeVar ,
13
- Union ,
14
- cast ,
15
- )
16
7
17
8
from literalai .api .base import BaseLiteralAPI , prepare_variables
18
-
19
9
from literalai .api .helpers .attachment_helpers import (
20
10
AttachmentUpload ,
21
11
create_attachment_helper ,
91
81
DatasetExperimentItem ,
92
82
)
93
83
from literalai .evaluation .dataset_item import DatasetItem
84
+ from literalai .my_types import PaginatedResponse , User
94
85
from literalai .observability .filter import (
95
86
generations_filters ,
96
87
generations_order_by ,
102
93
threads_order_by ,
103
94
users_filters ,
104
95
)
105
- from literalai .observability .thread import Thread
106
- from literalai .prompt_engineering .prompt import Prompt , ProviderSettings
107
-
108
- import httpx
109
-
110
- from literalai .my_types import PaginatedResponse , User
111
96
from literalai .observability .generation import (
112
97
BaseGeneration ,
113
98
ChatGeneration ,
123
108
StepDict ,
124
109
StepType ,
125
110
)
111
+ from literalai .observability .thread import Thread
112
+ from literalai .prompt_engineering .prompt import Prompt , ProviderSettings
126
113
127
114
logger = logging .getLogger (__name__ )
128
115
@@ -141,7 +128,11 @@ class AsyncLiteralAPI(BaseLiteralAPI):
141
128
R = TypeVar ("R" )
142
129
143
130
async def make_gql_call (
144
- self , description : str , query : str , variables : Dict [str , Any ], timeout : Optional [int ] = 10
131
+ self ,
132
+ description : str ,
133
+ query : str ,
134
+ variables : Dict [str , Any ],
135
+ timeout : Optional [int ] = 10 ,
145
136
) -> Dict :
146
137
def raise_error (error ):
147
138
logger .error (f"Failed to { description } : { error } " )
@@ -166,8 +157,7 @@ def raise_error(error):
166
157
json = response .json ()
167
158
except ValueError as e :
168
159
raise_error (
169
- f"""Failed to parse JSON response: {
170
- e } , content: { response .content !r} """
160
+ f"Failed to parse JSON response: { e } , content: { response .content !r} "
171
161
)
172
162
173
163
if json .get ("errors" ):
@@ -178,8 +168,7 @@ def raise_error(error):
178
168
for value in json ["data" ].values ():
179
169
if value and value .get ("ok" ) is False :
180
170
raise_error (
181
- f"""Failed to { description } : {
182
- value .get ('message' )} """
171
+ f"""Failed to { description } : { value .get ("message" )} """
183
172
)
184
173
return json
185
174
@@ -203,9 +192,9 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
203
192
return response .json ()
204
193
except ValueError as e :
205
194
raise ValueError (
206
- f"""Failed to parse JSON response: {
207
- e } , content: { response .content !r} """
195
+ f"Failed to parse JSON response: { e } , content: { response .content !r} "
208
196
)
197
+
209
198
async def gql_helper (
210
199
self ,
211
200
query : str ,
@@ -235,7 +224,9 @@ async def get_user(
235
224
) -> "User" :
236
225
return await self .gql_helper (* get_user_helper (id , identifier ))
237
226
238
- async def create_user (self , identifier : str , metadata : Optional [Dict ] = None ) -> "User" :
227
+ async def create_user (
228
+ self , identifier : str , metadata : Optional [Dict ] = None
229
+ ) -> "User" :
239
230
return await self .gql_helper (* create_user_helper (identifier , metadata ))
240
231
241
232
async def update_user (
@@ -245,7 +236,7 @@ async def update_user(
245
236
246
237
async def delete_user (self , id : str ) -> Dict :
247
238
return await self .gql_helper (* delete_user_helper (id ))
248
-
239
+
249
240
async def get_or_create_user (
250
241
self , identifier : str , metadata : Optional [Dict ] = None
251
242
) -> "User" :
@@ -273,7 +264,7 @@ async def get_threads(
273
264
first , after , before , filters , order_by , step_types_to_keep
274
265
)
275
266
)
276
-
267
+
277
268
async def list_threads (
278
269
self ,
279
270
first : Optional [int ] = None ,
@@ -491,7 +482,7 @@ async def create_attachment(
491
482
thread_id = active_thread .id
492
483
493
484
if not step_id :
494
- if active_steps := active_steps_var .get ([] ):
485
+ if active_steps := active_steps_var .get ():
495
486
step_id = active_steps [- 1 ].id
496
487
else :
497
488
raise Exception ("No step_id provided and no active step found." )
@@ -532,7 +523,9 @@ async def create_attachment(
532
523
response = await self .make_gql_call (description , query , variables )
533
524
return process_response (response )
534
525
535
- async def update_attachment (self , id : str , update_params : AttachmentUpload ) -> "Attachment" :
526
+ async def update_attachment (
527
+ self , id : str , update_params : AttachmentUpload
528
+ ) -> "Attachment" :
536
529
return await self .gql_helper (* update_attachment_helper (id , update_params ))
537
530
538
531
async def get_attachment (self , id : str ) -> Optional ["Attachment" ]:
@@ -545,7 +538,6 @@ async def delete_attachment(self, id: str) -> Dict:
545
538
# Step APIs #
546
539
##################################################################################
547
540
548
-
549
541
async def create_step (
550
542
self ,
551
543
thread_id : Optional [str ] = None ,
@@ -646,7 +638,7 @@ async def get_generations(
646
638
return await self .gql_helper (
647
639
* get_generations_helper (first , after , before , filters , order_by )
648
640
)
649
-
641
+
650
642
async def create_generation (
651
643
self , generation : Union ["ChatGeneration" , "CompletionGeneration" ]
652
644
) -> Union ["ChatGeneration" , "CompletionGeneration" ]:
@@ -667,8 +659,10 @@ async def create_dataset(
667
659
return await self .gql_helper (
668
660
* create_dataset_helper (sync_api , name , description , metadata , type )
669
661
)
670
-
671
- async def get_dataset (self , id : Optional [str ] = None , name : Optional [str ] = None ) -> "Dataset" :
662
+
663
+ async def get_dataset (
664
+ self , id : Optional [str ] = None , name : Optional [str ] = None
665
+ ) -> "Dataset" :
672
666
sync_api = LiteralAPI (self .api_key , self .url )
673
667
subpath , _ , variables , process_response = get_dataset_helper (
674
668
sync_api , id = id , name = name
@@ -738,7 +732,7 @@ async def create_experiment_item(
738
732
result .scores = await self .create_scores (experiment_item .scores )
739
733
740
734
return result
741
-
735
+
742
736
##################################################################################
743
737
# DatasetItem APIs #
744
738
##################################################################################
@@ -753,7 +747,7 @@ async def create_dataset_item(
753
747
return await self .gql_helper (
754
748
* create_dataset_item_helper (dataset_id , input , expected_output , metadata )
755
749
)
756
-
750
+
757
751
async def get_dataset_item (self , id : str ) -> "DatasetItem" :
758
752
return await self .gql_helper (* get_dataset_item_helper (id ))
759
753
@@ -784,7 +778,9 @@ async def get_or_create_prompt_lineage(
784
778
return await self .gql_helper (* create_prompt_lineage_helper (name , description ))
785
779
786
780
@deprecated ('Please use "get_or_create_prompt_lineage" instead.' )
787
- async def create_prompt_lineage (self , name : str , description : Optional [str ] = None ) -> Dict :
781
+ async def create_prompt_lineage (
782
+ self , name : str , description : Optional [str ] = None
783
+ ) -> Dict :
788
784
return await self .get_or_create_prompt_lineage (name , description )
789
785
790
786
async def get_or_create_prompt (
@@ -838,7 +834,14 @@ async def get_prompt(
838
834
raise ValueError ("At least the `id` or the `name` must be provided." )
839
835
840
836
sync_api = LiteralAPI (self .api_key , self .url )
841
- get_prompt_query , description , variables , process_response , timeout , cached_prompt = get_prompt_helper (
837
+ (
838
+ get_prompt_query ,
839
+ description ,
840
+ variables ,
841
+ process_response ,
842
+ timeout ,
843
+ cached_prompt ,
844
+ ) = get_prompt_helper (
842
845
api = sync_api , id = id , name = name , version = version , cache = self .cache
843
846
)
844
847
0 commit comments