diff --git a/src/aiq/profiler/callbacks/langchain_callback_handler.py b/src/aiq/profiler/callbacks/langchain_callback_handler.py index 88fc7696..07bd16ed 100644 --- a/src/aiq/profiler/callbacks/langchain_callback_handler.py +++ b/src/aiq/profiler/callbacks/langchain_callback_handler.py @@ -253,7 +253,7 @@ async def on_tool_start( usage_info=UsageInfo(token_usage=TokenUsageBaseModel())) self.step_manager.push_intermediate_step(stats) - self._run_id_to_tool_input[str(run_id)] = input_str + self._run_id_to_tool_input[str(run_id)] = copy.deepcopy(inputs) self._run_id_to_start_time[str(run_id)] = time.time() async def on_tool_end( @@ -265,14 +265,15 @@ async def on_tool_end( **kwargs: Any, ) -> Any: + inputs = self._run_id_to_tool_input.get(str(run_id), "") + stats = IntermediateStepPayload(event_type=IntermediateStepType.TOOL_END, span_event_timestamp=self._run_id_to_start_time.get(str(run_id), time.time()), framework=LLMFrameworkEnum.LANGCHAIN, name=kwargs.get("name", ""), UUID=str(run_id), - metadata=TraceMetadata(tool_outputs=output), + metadata=TraceMetadata(tool_inputs=inputs, tool_outputs=output), usage_info=UsageInfo(token_usage=TokenUsageBaseModel()), - data=StreamEventData(input=self._run_id_to_tool_input.get(str(run_id), ""), - output=output)) + data=StreamEventData(input=inputs, output=output)) self.step_manager.push_intermediate_step(stats)