@@ -166,6 +166,7 @@ def trace(
166
166
overwrite_trace_id : Optional [str ] = None ,
167
167
overwrite_inputs : Optional [Dict [str , Any ]] = None ,
168
168
log_sample_rate : Optional [float ] = 1.0 ,
169
+ fn_transform_generator_outputs : Callable [[List [Any ]], str ] = None ,
169
170
):
170
171
def init_trace (func_name , _parea_target_field , args , kwargs , func ) -> Tuple [str , datetime , contextvars .Token ]:
171
172
start_time = timezone_aware_now ()
@@ -258,24 +259,60 @@ def cleanup_trace(trace_id: str, start_time: datetime, context_token: contextvar
258
259
thread_eval_funcs_then_log (trace_id , eval_funcs )
259
260
trace_context .reset (context_token )
260
261
262
+ def _handle_iterator_cleanup (items , trace_id , start_time , context_token ):
263
+ if fn_transform_generator_outputs :
264
+ result = fn_transform_generator_outputs (items )
265
+ elif all (isinstance (item , str ) for item in items ):
266
+ result = "" .join (items )
267
+ else :
268
+ result = ""
269
+ if not is_logging_disabled () and not log_omit_outputs :
270
+ fill_trace_data (trace_id , {"result" : result }, UpdateTraceScenario .RESULT )
271
+
272
+ cleanup_trace (trace_id , start_time , context_token )
273
+
274
+ async def _wrap_async_iterator (iterator , trace_id , start_time , context_token ):
275
+ items = []
276
+ try :
277
+ async for item in iterator :
278
+ items .append (item )
279
+ yield item
280
+ finally :
281
+ _handle_iterator_cleanup (items , trace_id , start_time , context_token )
282
+
283
+ def _wrap_sync_iterator (iterator , trace_id , start_time , context_token ):
284
+ items = []
285
+ try :
286
+ for item in iterator :
287
+ items .append (item )
288
+ yield item
289
+ finally :
290
+ _handle_iterator_cleanup (items , trace_id , start_time , context_token )
291
+
261
292
def decorator (func ):
262
293
@wraps (func )
263
294
async def async_wrapper (* args , ** kwargs ):
264
295
_parea_target_field = kwargs .pop ("_parea_target_field" , None )
265
296
trace_id , start_time , context_token = init_trace (func .__name__ , _parea_target_field , args , kwargs , func )
266
297
output_as_list = check_multiple_return_values (func )
298
+ result = None
267
299
try :
268
300
result = await func (* args , ** kwargs )
269
301
if not is_logging_disabled () and not log_omit_outputs :
270
302
fill_trace_data (trace_id , {"result" : result , "output_as_list" : output_as_list , "eval_funcs_names" : eval_funcs_names }, UpdateTraceScenario .RESULT )
271
- return result
272
303
except Exception as e :
273
304
logger .error (f"Error occurred in function { func .__name__ } , { e } " )
274
305
fill_trace_data (trace_id , {"error" : traceback .format_exc ()}, UpdateTraceScenario .ERROR )
275
306
raise e
276
307
finally :
277
308
try :
278
- cleanup_trace (trace_id , start_time , context_token )
309
+ if inspect .isasyncgen (result ):
310
+ return _wrap_async_iterator (result , trace_id , start_time , context_token )
311
+ else :
312
+ cleanup_trace (trace_id , start_time , context_token )
313
+ # to not swallow any exceptions
314
+ if result is not None :
315
+ return result
279
316
except Exception as e :
280
317
logger .debug (f"Error occurred cleaning up trace for function { func .__name__ } , { e } " , exc_info = e )
281
318
@@ -284,18 +321,24 @@ def wrapper(*args, **kwargs):
284
321
_parea_target_field = kwargs .pop ("_parea_target_field" , None )
285
322
trace_id , start_time , context_token = init_trace (func .__name__ , _parea_target_field , args , kwargs , func )
286
323
output_as_list = check_multiple_return_values (func )
324
+ result = None
287
325
try :
288
326
result = func (* args , ** kwargs )
289
327
if not is_logging_disabled () and not log_omit_outputs :
290
328
fill_trace_data (trace_id , {"result" : result , "output_as_list" : output_as_list , "eval_funcs_names" : eval_funcs_names }, UpdateTraceScenario .RESULT )
291
- return result
292
329
except Exception as e :
293
330
logger .error (f"Error occurred in function { func .__name__ } , { e } " )
294
331
fill_trace_data (trace_id , {"error" : traceback .format_exc ()}, UpdateTraceScenario .ERROR )
295
332
raise e
296
333
finally :
297
334
try :
298
- cleanup_trace (trace_id , start_time , context_token )
335
+ if inspect .isgenerator (result ):
336
+ return _wrap_sync_iterator (result , trace_id , start_time , context_token )
337
+ else :
338
+ cleanup_trace (trace_id , start_time , context_token )
339
+ # to not swallow any exceptions
340
+ if result is not None :
341
+ return result
299
342
except Exception as e :
300
343
logger .debug (f"Error occurred cleaning up trace for function { func .__name__ } , { e } " , exc_info = e )
301
344
0 commit comments