Skip to content

feat: tool_events #1422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/google/adk/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class Event(LlmResponse):
conversation history.
"""

tool_events: Optional[list[Event]] = None
"""The tool event in this event.
This is used to store the tool events that are generated by the agent.
"""


# The following are computed fields.
# Do not assign the ID. It will be assigned by the session.
id: str = ''
Expand Down
21 changes: 16 additions & 5 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ async def handle_function_calls_async(
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response: Optional[dict] = None
function_response: Optional[dict] = None
tool_events: Optional[list[Event]] = None

for callback in agent.canonical_before_tool_callbacks:
function_response = callback(
Expand All @@ -168,6 +169,9 @@ async def handle_function_calls_async(
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
if isinstance(function_response, tuple):
function_response, tool_events = function_response


for callback in agent.canonical_after_tool_callbacks:
altered_function_response = callback(
Expand All @@ -189,7 +193,7 @@ async def handle_function_calls_async(

# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
tool, function_response, tool_context, invocation_context, tool_events
)
trace_tool_call(
tool=tool,
Expand Down Expand Up @@ -228,6 +232,7 @@ async def handle_function_calls_live(
function_calls = function_call_event.get_function_calls()

function_response_events: list[Event] = []
tool_events: Optional[list[Event]] = None
for function_call in function_calls:
tool, tool_context = _get_tool_and_context(
invocation_context, function_call_event, function_call, tools_dict
Expand All @@ -250,7 +255,7 @@ async def handle_function_calls_live(
function_response = await function_response

if not function_response:
function_response = await _process_function_live_helper(
function_response, tool_events = await _process_function_live_helper(
tool, tool_context, function_call, function_args, invocation_context
)

Expand Down Expand Up @@ -283,7 +288,7 @@ async def handle_function_calls_live(

# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
tool, function_response, tool_context, invocation_context, tool_events
)
trace_tool_call(
tool=tool,
Expand Down Expand Up @@ -313,6 +318,7 @@ async def _process_function_live_helper(
tool, tool_context, function_call, function_args, invocation_context
):
function_response = None
tool_events = None
# Check if this is a stop_streaming function call
if (
function_call.name == 'stop_streaming'
Expand Down Expand Up @@ -401,7 +407,10 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
return function_response
if isinstance(function_response, tuple):
function_response, tool_events = function_response

return function_response, tool_events


def _get_tool_and_context(
Expand Down Expand Up @@ -454,6 +463,7 @@ def __build_response_event(
function_result: dict[str, object],
tool_context: ToolContext,
invocation_context: InvocationContext,
tool_events: Optional[list[Event]],
) -> Event:
# Specs requires the result to be a dict.
if not isinstance(function_result, dict):
Expand All @@ -475,6 +485,7 @@ def __build_response_event(
content=content,
actions=tool_context.actions,
branch=invocation_context.branch,
tool_events=tool_events
)

return function_response_event
Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ._forwarding_artifact_service import ForwardingArtifactService
from .base_tool import BaseTool
from .tool_context import ToolContext
from ..events.event import Event

if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent
Expand Down Expand Up @@ -124,10 +125,12 @@ async def run_async(
)

last_event = None
tool_events: list[Event] = []
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
):
# Forward state delta to parent session.
tool_events.append(event)
if event.actions.state_delta:
tool_context.state.update(event.actions.state_delta)
last_event = event
Expand All @@ -141,4 +144,4 @@ async def run_async(
).model_dump(exclude_none=True)
else:
tool_result = merged_text
return tool_result
return tool_result, tool_events