192 changes: 173 additions & 19 deletions pyagentspec/src/pyagentspec/adapters/langgraph/_langgraphconverter.py
@@ -5,7 +5,7 @@
# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option.


from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, Generator, List, Literal, Optional, Tuple, Union
from uuid import uuid4

import httpx
@@ -32,9 +32,13 @@
langgraph_graph,
langgraph_prebuilt,
)
from pyagentspec.adapters.langgraph.tracing import AgentSpecCallbackHandler
from pyagentspec.adapters.langgraph.tracing import (
AgentSpecLlmCallbackHandler,
AgentSpecToolCallbackHandler,
)
from pyagentspec.agent import Agent as AgentSpecAgent
from pyagentspec.flows.edges.controlflowedge import ControlFlowEdge
from pyagentspec.flows.edges import ControlFlowEdge as AgentSpecControlFlowEdge
from pyagentspec.flows.edges import DataFlowEdge as AgentSpecDataFlowEdge
from pyagentspec.flows.flow import Flow as AgentSpecFlow
from pyagentspec.flows.node import Node as AgentSpecNode
from pyagentspec.flows.nodes import AgentNode as AgentSpecAgentNode
@@ -63,6 +67,12 @@
from pyagentspec.tools import RemoteTool as AgentSpecRemoteTool
from pyagentspec.tools import ServerTool as AgentSpecServerTool
from pyagentspec.tools import Tool as AgentSpecTool
from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd
from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart
from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd
from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart
from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan
from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan


class SchemaRegistry:
Expand Down Expand Up @@ -210,17 +220,12 @@ def _convert(
config: RunnableConfig,
) -> Any:
if isinstance(agentspec_component, AgentSpecAgent):
callback = AgentSpecCallbackHandler(
llm_config=agentspec_component.llm_config,
tools=agentspec_component.tools,
)
config_with_callbacks = _add_callback_to_runnable_config(callback, config)
return self._agent_convert_to_langgraph(
agentspec_component,
tool_registry=tool_registry,
converted_components=converted_components,
checkpointer=checkpointer,
config=config_with_callbacks,
config=config,
)
elif isinstance(agentspec_component, AgentSpecLlmConfig):
return self._llm_convert_to_langgraph(agentspec_component, config=config)
@@ -263,7 +268,7 @@ def _convert(
)

def _create_control_flow(
self, control_flow_connections: List[ControlFlowEdge]
self, control_flow_connections: List[AgentSpecControlFlowEdge]
) -> "ControlFlow":
control_flow: "ControlFlow" = {}
for control_flow_edge in control_flow_connections:
@@ -299,6 +304,7 @@ def _flow_convert_to_langgraph(
graph_builder = StateGraph(
FlowStateSchema, input_schema=FlowInputSchema, output_schema=FlowOutputSchema
)

graph_builder.add_edge(langgraph_graph.START, flow.start_node.id)

node_executors = {
@@ -342,12 +348,99 @@ def _find_property(properties: List[AgentSpecProperty], name: str) -> AgentSpecP
for node_id, node_executor in node_executors.items():
graph_builder.add_node(node_id, node_executor)

for data_flow_edge in flow.data_flow_connections or []:
data_flow_connections: List[AgentSpecDataFlowEdge] = []
if flow.data_flow_connections is None:
# We manually create data flow connections if they are not given in the flow
# This is the conversion recommended in the Agent Spec language specification
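# Outputs and inputs are matched by title across every pair of nodes (including a node with itself):
# each output is wired into every identically named input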
for source_node in flow.nodes:
for destination_node in flow.nodes:
for source_output in source_node.outputs or []:
for destination_input in destination_node.inputs or []:
if source_output.title == destination_input.title:
data_flow_connections.append(
AgentSpecDataFlowEdge(
name=f"{source_node.name}-{destination_node.name}-{source_output.title}",
source_node=source_node,
source_output=source_output.title,
destination_node=destination_node,
destination_input=destination_input.title,
)
)
else:
data_flow_connections = flow.data_flow_connections

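# Edges are attached to their source node's executor, so each executor knows which downstream inputs its outputs feed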
for data_flow_edge in data_flow_connections:
node_executors[data_flow_edge.source_node.id].attach_edge(data_flow_edge)

control_flow: "ControlFlow" = self._create_control_flow(flow.control_flow_connections)
self._add_conditional_edges_to_graph(control_flow, graph_builder)
return graph_builder.compile(checkpointer=checkpointer)
compiled_graph = graph_builder.compile(checkpointer=checkpointer)

# To emit flow execution traces, monkey patch the functions that run the compiled graph
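# Keep references to the original bound methods so the wrappers below can delegate to them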

original_stream = compiled_graph.stream

def patch_with_flow_execution_span(*args: Any, **kwargs: Any) -> Generator[Any, Any, None]:
span_name = f"FlowExecution[{flow.name}]"
inputs = kwargs.get("input", {})
if not isinstance(inputs, dict):
inputs = {}
with AgentSpecFlowExecutionSpan(name=span_name, flow=flow) as span:
span.add_event(AgentSpecFlowExecutionStart(flow=flow, inputs=inputs))
original_result: dict[str, Any] | Any = {}
result: dict[str, Any]
# Wraps the original `stream` generator: every chunk is forwarded to the caller, and the last state payload is kept as the final result
for chunk in original_stream(*args, **kwargs):
yield chunk
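# Depending on the requested stream mode, LangGraph chunks may be (mode, payload) tuples; the last payload seen is treated as the final state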
if isinstance(chunk, tuple):
original_result = chunk[1]
if not isinstance(original_result, dict):
result = {}
else:
result = original_result
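# "outputs" and "node_execution_details" are assumed keys of the final FlowStateSchema state; "branch" holds the last branch selection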
span.add_event(
AgentSpecFlowExecutionEnd(
flow=flow,
outputs=result.get("outputs", {}),
branch_selected=result.get("node_execution_details", {}).get("branch", ""),
)
)

original_astream = compiled_graph.astream

async def patch_async_with_flow_execution_span(
*args: Any, **kwargs: Any
) -> AsyncGenerator[Any, Any]:
span_name = f"FlowExecution[{flow.name}]"
inputs = kwargs.get("input", {})
if not isinstance(inputs, dict):
inputs = {}
with AgentSpecFlowExecutionSpan(name=span_name, flow=flow) as span:
span.add_event(AgentSpecFlowExecutionStart(flow=flow, inputs=inputs))
original_result: dict[str, Any] | Any = {}
result: dict[str, Any]
# Wraps the original `astream` async generator: every chunk is forwarded to the caller, and the last state payload is kept as the final result
async for chunk in original_astream(*args, **kwargs):
yield chunk
if isinstance(chunk, tuple):
original_result = chunk[1]
if not isinstance(original_result, dict):
result = {}
else:
result = original_result
span.add_event(
AgentSpecFlowExecutionEnd(
flow=flow,
outputs=result.get("outputs", {}),
branch_selected=result.get("node_execution_details", {}).get("branch", ""),
)
)

# Monkey patch invocation functions to inject tracing
# There is no need to patch `invoke`/`ainvoke`, as they internally use `stream`/`astream`
compiled_graph.stream = patch_with_flow_execution_span # type: ignore
compiled_graph.astream = patch_async_with_flow_execution_span # type: ignore
return compiled_graph

def _node_convert_to_langgraph(
self,
@@ -593,7 +686,7 @@ def _remote_tool(**kwargs: Any) -> Any:
description=remote_tool.description or "",
args_schema=args_model,
func=_remote_tool,
callbacks=config.get("callbacks"),
callbacks=[AgentSpecToolCallbackHandler(tool=remote_tool)],
)
return structured_tool

@@ -629,7 +722,7 @@ def _server_tool_convert_to_langgraph(
description=description,
args_schema=args_model, # model class, not a dict
func=tool_obj,
callbacks=config.get("callbacks"),
callbacks=[AgentSpecToolCallbackHandler(tool=agentspec_server_tool)],
)
return wrapped

@@ -666,13 +759,15 @@ def client_tool(*args: Any, **kwargs: Any) -> Any:
description=agentspec_client_tool.description or "",
args_schema=args_model,
func=client_tool,
# We do not attach the tool execution callback here, as tool execution events are not expected for client tools
)
return structured_tool

def _create_react_agent_with_given_info(
self,
name: str,
system_prompt: str,
agent: AgentSpecAgent,
llm_config: AgentSpecLlmConfig,
tools: List[AgentSpecTool],
inputs: List[AgentSpecProperty],
@@ -726,7 +821,7 @@ def _create_react_agent_with_given_info(
if outputs:
output_model = _create_pydantic_model_from_properties("AgentOutputModel", outputs)

return langgraph_prebuilt.create_react_agent(
compiled_graph = langgraph_prebuilt.create_react_agent(
name=name,
model=model,
tools=langgraph_tools,
@@ -736,6 +831,62 @@
state_schema=input_model,
)

# To emit agent execution traces, monkey patch the functions that run the compiled graph

original_stream = compiled_graph.stream

def patch_with_agent_execution_span(*args: Any, **kwargs: Any) -> Generator[Any, Any, None]:
span_name = f"AgentExecution[{agent.name}]"
inputs = kwargs.get("input", {})
if not isinstance(inputs, dict):
inputs = {}
with AgentSpecAgentExecutionSpan(name=span_name, agent=agent) as span:
span.add_event(AgentSpecAgentExecutionStart(agent=agent, inputs=inputs))
original_result: dict[str, Any] | Any = {}
result: dict[str, Any]
# Wraps the original `stream` generator: every chunk is forwarded to the caller, and the last state payload is kept as the final result
for chunk in original_stream(*args, **kwargs):
yield chunk
if isinstance(chunk, tuple):
original_result = chunk[1]
if not isinstance(original_result, dict):
result = {}
else:
result = original_result
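# When a response format is configured, `create_react_agent` exposes the structured output under the "structured_response" state key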
outputs = dict(result.get("structured_response", {}))
span.add_event(AgentSpecAgentExecutionEnd(agent=agent, outputs=outputs))

original_astream = compiled_graph.astream

async def patch_async_with_agent_execution_span(
*args: Any, **kwargs: Any
) -> AsyncGenerator[Any, Any]:
span_name = f"AgentExecution[{agent.name}]"
inputs = kwargs.get("input", {})
if not isinstance(inputs, dict):
inputs = {}
with AgentSpecAgentExecutionSpan(name=span_name, agent=agent) as span:
span.add_event(AgentSpecAgentExecutionStart(agent=agent, inputs=inputs))
original_result: dict[str, Any] | Any = {}
result: dict[str, Any]
# Wraps the original `astream` async generator: every chunk is forwarded to the caller, and the last state payload is kept as the final result
async for chunk in original_astream(*args, **kwargs):
yield chunk
if isinstance(chunk, tuple):
original_result = chunk[1]
if not isinstance(original_result, dict):
result = {}
else:
result = original_result
outputs = dict(result.get("structured_response", {}))
span.add_event(AgentSpecAgentExecutionEnd(agent=agent, outputs=outputs))

# Monkey patch invocation functions to inject tracing
# There is no need to patch `invoke`/`ainvoke`, as they internally use `stream`/`astream`
compiled_graph.stream = patch_with_agent_execution_span # type: ignore
compiled_graph.astream = patch_async_with_agent_execution_span # type: ignore
return compiled_graph

def _agent_convert_to_langgraph(
self,
agentspec_component: AgentSpecAgent,
@@ -747,6 +898,7 @@ def _agent_convert_to_langgraph(
return self._create_react_agent_with_given_info(
name=agentspec_component.name,
system_prompt=agentspec_component.system_prompt,
agent=agentspec_component,
llm_config=agentspec_component.llm_config,
tools=agentspec_component.tools,
inputs=agentspec_component.inputs or [],
@@ -773,6 +925,8 @@ def _llm_convert_to_langgraph(
if isinstance(llm_config, (OpenAiCompatibleConfig, OpenAiConfig)):
use_responses_api = llm_config.api_type == OpenAIAPIType.RESPONSES

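# One Agent Spec LLM callback handler per chat model, so LLM calls are traced even when the caller's RunnableConfig carries no callbacks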
callbacks: List[BaseCallbackHandler] = [AgentSpecLlmCallbackHandler(llm_config=llm_config)]

if isinstance(llm_config, VllmConfig):
from langchain_openai import ChatOpenAI

@@ -781,7 +935,7 @@
api_key=SecretStr("EMPTY"),
base_url=_prepare_openai_compatible_url(llm_config.url),
use_responses_api=use_responses_api,
callbacks=config.get("callbacks"),
callbacks=callbacks,
**generation_config,
)
elif isinstance(llm_config, OllamaConfig):
@@ -795,7 +949,7 @@
return ChatOllama(
base_url=llm_config.url,
model=llm_config.model_id,
callbacks=config.get("callbacks"),
callbacks=callbacks,
**generation_config,
)
elif isinstance(llm_config, OpenAiConfig):
@@ -804,7 +958,7 @@
return ChatOpenAI(
model=llm_config.model_id,
use_responses_api=use_responses_api,
callbacks=config.get("callbacks"),
callbacks=callbacks,
**generation_config,
)
elif isinstance(llm_config, OpenAiCompatibleConfig):
@@ -814,7 +968,7 @@
model=llm_config.model_id,
base_url=_prepare_openai_compatible_url(llm_config.url),
use_responses_api=use_responses_api,
callbacks=config.get("callbacks"),
callbacks=callbacks,
**generation_config,
)
else: