Skip to content

Commit 0517b1f

Browse files
authored
Merge pull request #1051 from parea-ai/PAI-1464-make-trace-decorator-work-with-iterator-responses
Pai 1464 make trace decorator work with iterator responses
2 parents 8151f6a + 9d46118 commit 0517b1f

File tree

7 files changed

+130
-58
lines changed

7 files changed

+130
-58
lines changed
Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,81 @@
11
import os
22

3+
import anthropic
34
import instructor
45
from dotenv import load_dotenv
56
from openai import AsyncOpenAI
7+
from pydantic import BaseModel
68

7-
from parea import Parea
9+
from parea import Parea, trace
810

911
load_dotenv()
1012

11-
client = AsyncOpenAI()
13+
oai_aclient = AsyncOpenAI()
14+
ant_client = anthropic.AsyncClient()
1215

1316
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
14-
p.wrap_openai_client(client, "instructor")
15-
16-
client = instructor.from_openai(client)
1717

18+
p.wrap_openai_client(oai_aclient, "instructor")
19+
p.wrap_anthropic_client(ant_client)
1820

19-
from pydantic import BaseModel
21+
oai_aclient = instructor.from_openai(oai_aclient)
22+
ant_client = instructor.from_anthropic(ant_client)
2023

2124

2225
class UserDetail(BaseModel):
2326
name: str
24-
age: int
27+
age: str
2528

2629

27-
async def main():
28-
user = client.completions.create_partial(
29-
model="gpt-3.5-turbo",
30+
@trace
31+
async def ainner_main():
32+
user = oai_aclient.completions.create_partial(
33+
model="gpt-4o-mini",
3034
max_tokens=1024,
3135
max_retries=3,
3236
messages=[
3337
{
3438
"role": "user",
35-
"content": "Please crea a user",
39+
"content": "Please create a user",
3640
}
3741
],
3842
response_model=UserDetail,
3943
)
40-
# print(user)
41-
async for u in user:
44+
return user
45+
46+
47+
async def amain():
48+
resp = await ainner_main()
49+
async for u in resp:
4250
print(u)
4351

44-
# user2 = client.completions.create_partial(
45-
# model="gpt-3.5-turbo",
46-
# max_tokens=1024,
47-
# max_retries=3,
48-
# messages=[
49-
# {
50-
# "role": "user",
51-
# "content": "Please crea a user",
52-
# }
53-
# ],
54-
# response_model=UserDetail,
55-
# )
56-
# async for u in user2:
57-
# print(u)
52+
53+
@trace
54+
def inner_main():
55+
user = ant_client.completions.create_partial(
56+
model="claude-3-5-sonnet-20240620",
57+
max_tokens=1024,
58+
max_retries=3,
59+
messages=[
60+
{
61+
"role": "user",
62+
"content": "Please create a user",
63+
}
64+
],
65+
response_model=UserDetail,
66+
)
67+
return user
68+
69+
70+
def main():
71+
resp = inner_main()
72+
for u in resp:
73+
print(u)
5874

5975

6076
if __name__ == "__main__":
6177
import asyncio
6278

63-
asyncio.run(main())
79+
asyncio.run(amain())
80+
81+
main()

cookbook/openai/tracing_open_ai_streams.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,15 @@
1818

1919

2020
@trace
21-
def call_openai_stream(data: dict):
21+
def _call_openai_stream(data: dict):
2222
data["stream"] = True
2323
stream = client.chat.completions.create(**data)
24+
for chunk in stream:
25+
yield chunk
26+
27+
28+
def call_openai_stream(data: dict):
29+
stream = _call_openai_stream(data)
2430
for chunk in stream:
2531
print(chunk.choices[0].delta or "")
2632

parea/utils/trace_integrations/instructor.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Any, Callable, Mapping, Tuple
1+
from typing import Any, Callable, List, Mapping, Tuple
22

33
import contextvars
4+
import logging
45
from json import JSONDecodeError
56

67
from instructor.retry import InstructorRetryException
@@ -12,6 +13,9 @@
1213
from parea.schemas import EvaluationResult, UpdateLog
1314
from parea.utils.trace_integrations.wrapt_utils import CopyableFunctionWrapper
1415
from parea.utils.trace_utils import logger_update_record, trace_data, trace_insert
16+
from parea.utils.universal_encoder import json_dumps
17+
18+
logger = logging.getLogger()
1519

1620
instructor_trace_id = contextvars.ContextVar("instructor_trace_id", default="")
1721
instructor_val_err_count = contextvars.ContextVar("instructor_val_err_count", default=0)
@@ -50,14 +54,11 @@ def report_instructor_validation_errors() -> None:
5054
score=instructor_val_err_count.get(),
5155
reason=reason,
5256
)
53-
last_child_trace_id = trace_data.get()[instructor_trace_id.get()].children[-1]
54-
trace_insert(
55-
{
56-
"scores": [instructor_score],
57-
"configuration": trace_data.get()[last_child_trace_id].configuration,
58-
},
59-
instructor_trace_id.get(),
60-
)
57+
trace_update_dict = {"scores": [instructor_score]}
58+
if children := trace_data.get()[instructor_trace_id.get()].children:
59+
last_child_trace_id = children[-1]
60+
trace_update_dict["configuration"] = trace_data.get()[last_child_trace_id].configuration
61+
trace_insert(trace_update_dict, instructor_trace_id.get())
6162
instructor_trace_id.set("")
6263
instructor_val_err_count.set(0)
6364
instructor_val_errs.set([])
@@ -82,12 +83,15 @@ def __call__(
8283
trace_name = "instructor"
8384
if "response_model" in kwargs and kwargs["response_model"] and hasattr(kwargs["response_model"], "__name__"):
8485
trace_name = kwargs["response_model"].__name__
85-
return trace(
86-
name=trace_name,
87-
overwrite_trace_id=trace_id,
88-
overwrite_inputs=inputs,
89-
metadata=metadata,
90-
)(
86+
87+
def fn_transform_generator_outputs(items: List) -> str:
88+
try:
89+
return json_dumps(items[-1])
90+
except Exception as e:
91+
logger.warning(f"Failed to serialize generator output: {e}", exc_info=e)
92+
return ""
93+
94+
return trace(name=trace_name, overwrite_trace_id=trace_id, overwrite_inputs=inputs, metadata=metadata, fn_transform_generator_outputs=fn_transform_generator_outputs)(
9195
wrapped
9296
)(*args, **kwargs)
9397
except (InstructorRetryException, ValidationError, JSONDecodeError) as e:

parea/utils/trace_utils.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def trace(
166166
overwrite_trace_id: Optional[str] = None,
167167
overwrite_inputs: Optional[Dict[str, Any]] = None,
168168
log_sample_rate: Optional[float] = 1.0,
169+
fn_transform_generator_outputs: Callable[[List[Any]], str] = None,
169170
):
170171
def init_trace(func_name, _parea_target_field, args, kwargs, func) -> Tuple[str, datetime, contextvars.Token]:
171172
start_time = timezone_aware_now()
@@ -258,24 +259,60 @@ def cleanup_trace(trace_id: str, start_time: datetime, context_token: contextvar
258259
thread_eval_funcs_then_log(trace_id, eval_funcs)
259260
trace_context.reset(context_token)
260261

262+
def _handle_iterator_cleanup(items, trace_id, start_time, context_token):
263+
if fn_transform_generator_outputs:
264+
result = fn_transform_generator_outputs(items)
265+
elif all(isinstance(item, str) for item in items):
266+
result = "".join(items)
267+
else:
268+
result = ""
269+
if not is_logging_disabled() and not log_omit_outputs:
270+
fill_trace_data(trace_id, {"result": result}, UpdateTraceScenario.RESULT)
271+
272+
cleanup_trace(trace_id, start_time, context_token)
273+
274+
async def _wrap_async_iterator(iterator, trace_id, start_time, context_token):
275+
items = []
276+
try:
277+
async for item in iterator:
278+
items.append(item)
279+
yield item
280+
finally:
281+
_handle_iterator_cleanup(items, trace_id, start_time, context_token)
282+
283+
def _wrap_sync_iterator(iterator, trace_id, start_time, context_token):
284+
items = []
285+
try:
286+
for item in iterator:
287+
items.append(item)
288+
yield item
289+
finally:
290+
_handle_iterator_cleanup(items, trace_id, start_time, context_token)
291+
261292
def decorator(func):
262293
@wraps(func)
263294
async def async_wrapper(*args, **kwargs):
264295
_parea_target_field = kwargs.pop("_parea_target_field", None)
265296
trace_id, start_time, context_token = init_trace(func.__name__, _parea_target_field, args, kwargs, func)
266297
output_as_list = check_multiple_return_values(func)
298+
result = None
267299
try:
268300
result = await func(*args, **kwargs)
269301
if not is_logging_disabled() and not log_omit_outputs:
270302
fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
271-
return result
272303
except Exception as e:
273304
logger.error(f"Error occurred in function {func.__name__}, {e}")
274305
fill_trace_data(trace_id, {"error": traceback.format_exc()}, UpdateTraceScenario.ERROR)
275306
raise e
276307
finally:
277308
try:
278-
cleanup_trace(trace_id, start_time, context_token)
309+
if inspect.isasyncgen(result):
310+
return _wrap_async_iterator(result, trace_id, start_time, context_token)
311+
else:
312+
cleanup_trace(trace_id, start_time, context_token)
313+
# to not swallow any exceptions
314+
if result is not None:
315+
return result
279316
except Exception as e:
280317
logger.debug(f"Error occurred cleaning up trace for function {func.__name__}, {e}", exc_info=e)
281318

@@ -284,18 +321,24 @@ def wrapper(*args, **kwargs):
284321
_parea_target_field = kwargs.pop("_parea_target_field", None)
285322
trace_id, start_time, context_token = init_trace(func.__name__, _parea_target_field, args, kwargs, func)
286323
output_as_list = check_multiple_return_values(func)
324+
result = None
287325
try:
288326
result = func(*args, **kwargs)
289327
if not is_logging_disabled() and not log_omit_outputs:
290328
fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
291-
return result
292329
except Exception as e:
293330
logger.error(f"Error occurred in function {func.__name__}, {e}")
294331
fill_trace_data(trace_id, {"error": traceback.format_exc()}, UpdateTraceScenario.ERROR)
295332
raise e
296333
finally:
297334
try:
298-
cleanup_trace(trace_id, start_time, context_token)
335+
if inspect.isgenerator(result):
336+
return _wrap_sync_iterator(result, trace_id, start_time, context_token)
337+
else:
338+
cleanup_trace(trace_id, start_time, context_token)
339+
# to not swallow any exceptions
340+
if result is not None:
341+
return result
299342
except Exception as e:
300343
logger.debug(f"Error occurred cleaning up trace for function {func.__name__}, {e}", exc_info=e)
301344

parea/wrapper/anthropic/anthropic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from datetime import datetime
66

77
from anthropic import AsyncMessageStreamManager, AsyncStream, Client, MessageStreamManager, Stream
8-
from anthropic.types import ContentBlockDeltaEvent, Message, MessageDeltaEvent, MessageStartEvent, TextBlock
8+
from anthropic.types import ContentBlockDeltaEvent, InputJSONDelta, Message, MessageDeltaEvent, MessageStartEvent, TextBlock, ToolUseBlock
99

1010
from parea.cache.cache import Cache
1111
from parea.helpers import timezone_aware_now
@@ -43,8 +43,6 @@ def init(self, log: Callable, cache: Cache, client: Client):
4343
def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Message]) -> Optional[Any]:
4444
if response:
4545
if len(response.content) > 1:
46-
from anthropic.types.beta.tools import ToolUseBlock
47-
4846
output_list = []
4947
for content in response.content:
5048
if isinstance(content, TextBlock):
@@ -185,7 +183,10 @@ def _update_accumulator_streaming(accumulator, info_from_response, chunk):
185183
if isinstance(chunk, MessageStartEvent):
186184
info_from_response["input_tokens"] = chunk.message.usage.input_tokens
187185
elif isinstance(chunk, ContentBlockDeltaEvent):
188-
accumulator["content"].append(chunk.delta.text)
186+
if isinstance(chunk.delta, InputJSONDelta):
187+
accumulator["content"].append(chunk.delta.partial_json)
188+
else:
189+
accumulator["content"].append(chunk.delta.text)
189190
if not info_from_response.get("first_token_timestamp"):
190191
info_from_response["first_token_timestamp"] = timezone_aware_now()
191192
elif isinstance(chunk, MessageDeltaEvent):

parea/wrapper/anthropic/stream_wrapper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from types import TracebackType
44
from typing import Callable
55

6-
from anthropic import AsyncMessageStreamManager, MessageStreamManager, Stream
6+
from anthropic import AsyncMessageStreamManager, AsyncStream, MessageStreamManager, Stream
77
from anthropic.types import Message
88

99

@@ -16,8 +16,8 @@ def __init__(self, stream: Stream, accumulator, info_from_response, update_accum
1616
self._info_from_response = info_from_response
1717

1818
def __getattr__(self, attr):
19-
# delegate attribute access to the original async_stream
20-
return getattr(self._async_stream, attr)
19+
# delegate attribute access to the original stream
20+
return getattr(self._stream, attr) if hasattr(self._stream, attr) else None
2121

2222
def __iter__(self):
2323
for chunk in self._stream:
@@ -28,7 +28,7 @@ def __iter__(self):
2828

2929

3030
class AnthropicAsyncStreamWrapper:
31-
def __init__(self, stream: Stream, accumulator, info_from_response, update_accumulator_streaming, final_processing_and_logging):
31+
def __init__(self, stream: AsyncStream, accumulator, info_from_response, update_accumulator_streaming, final_processing_and_logging):
3232
self._stream = stream
3333
self._final_processing_and_logging = final_processing_and_logging
3434
self._update_accumulator_streaming = update_accumulator_streaming
@@ -37,7 +37,7 @@ def __init__(self, stream: Stream, accumulator, info_from_response, update_accum
3737

3838
def __getattr__(self, attr):
3939
# delegate attribute access to the original async_stream
40-
return getattr(self._async_stream, attr)
40+
return getattr(self._stream, attr) if hasattr(self._stream, attr) else None
4141

4242
async def __aiter__(self):
4343
async for chunk in self._stream:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
66
[tool.poetry]
77
name = "parea-ai"
88
packages = [{ include = "parea" }]
9-
version = "0.2.201"
9+
version = "0.2.202"
1010
description = "Parea python sdk"
1111
readme = "README.md"
1212
authors = ["joel-parea-ai <[email protected]>"]

0 commit comments

Comments
 (0)