diff --git a/src/agentscope/_version.py b/src/agentscope/_version.py index 48144febbb..6cb5253455 100644 --- a/src/agentscope/_version.py +++ b/src/agentscope/_version.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- """The version of agentscope.""" -__version__ = "1.0.17" +__version__ = "1.0.18dev" diff --git a/src/agentscope/tool/_toolkit.py b/src/agentscope/tool/_toolkit.py index 4e5f39360f..d17257f65c 100644 --- a/src/agentscope/tool/_toolkit.py +++ b/src/agentscope/tool/_toolkit.py @@ -114,7 +114,7 @@ async def wrapped( return wrapper -class Toolkit(StateModule): +class Toolkit(StateModule): # pylint: disable=too-many-public-methods """Toolkit is the core module to register, manage and delete tool functions, MCP clients, Agent skills in AgentScope. @@ -179,6 +179,11 @@ def __init__( agent_skill_template or self._DEFAULT_AGENT_SKILL_TEMPLATE ) + # This is an experimental feature to allow the tool function to be + # executed in an async way + self._async_tasks: dict[str, asyncio.Task] = {} + self._async_results: dict[str, ToolResponse] = {} + def create_tool_group( self, group_name: str, @@ -294,6 +299,7 @@ def register_tool_function( "raise", "rename", ] = "raise", + async_execution: bool = False, ) -> None: """Register a tool function to the toolkit. @@ -350,6 +356,12 @@ def register_tool_function( - 'skip': skip the registration of the new tool function. - 'rename': rename the new tool function by appending a random suffix to make it unique. + async_execution (`bool`, defaults to `False`): + If this tool function is executed in an async manner, a + reminder with task id will be sent to the agent, allowing the + agent to view, cancel or check the status of the async task. + **This is an experimental feature and may cause unexpected + issues, please use it with caution.** """ # Arguments checking if group_name not in self.groups and group_name != "basic": @@ -457,6 +469,7 @@ def register_tool_function( extended_model=None, mcp_name=mcp_name, postprocess_func=postprocess_func, + async_execution=async_execution, ) if func_name in self.tools: @@ -669,6 +682,172 @@ async def remove_mcp_clients( ", ".join(to_removed), ) + async def _execute_tool_in_background( + self, + task_id: str, + tool_func: RegisteredToolFunction, + kwargs: dict, + partial_postprocess_func: ( + Callable[[ToolResponse], ToolResponse | None] + | Callable[[ToolResponse], Awaitable[ToolResponse | None]] + ) + | None, + ) -> None: + """Execute a tool function in the background and store the result. + + This function handles both streaming and non-streaming tool functions. + For streaming functions (generators/async generators), it accumulates + all chunks into a single final ToolResponse. + + Args: + task_id (`str`): + The unique identifier for this async task. + tool_func (`RegisteredToolFunction`): + The registered tool function to execute. + kwargs (`dict`): + The keyword arguments to pass to the tool function. + partial_postprocess_func (`Callable | None`): + Optional postprocess function to apply to the result. + """ + try: + # Execute the tool function + if inspect.iscoroutinefunction(tool_func.original_func): + try: + res = await tool_func.original_func(**kwargs) + except asyncio.CancelledError: + res = ToolResponse( + content=[ + TextBlock( + type="text", + text="" + "The tool call has been interrupted " + "by the user." + "", + ), + ], + stream=True, + is_last=True, + is_interrupted=True, + ) + else: + # When `tool_func.original_func` is Async generator function or + # Sync function + res = tool_func.original_func(**kwargs) + + except mcp.shared.exceptions.McpError as e: + res = ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Error occurred when calling MCP tool: {e}", + ), + ], + ) + + except Exception as e: + res = ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Error: {e}", + ), + ], + ) + + # Handle different return types and accumulate streaming results + final_result: ToolResponse = ToolResponse(content=[]) + + try: + # If return an async generator - accumulate all chunks + if isinstance(res, AsyncGenerator): + accumulated_content = [] + last_chunk = None + async for chunk in res: + accumulated_content.extend(chunk.content) + last_chunk = chunk + + # Create final accumulated response + final_result = ToolResponse( + content=accumulated_content, + stream=False, + is_last=True, + is_interrupted=last_chunk.is_interrupted + if last_chunk + else False, + ) + + # If return a sync generator - accumulate all chunks + elif isinstance(res, Generator): + accumulated_content = [] + last_chunk = None + for chunk in res: + accumulated_content.extend(chunk.content) + last_chunk = chunk + + # Create final accumulated response + final_result = ToolResponse( + content=accumulated_content, + stream=False, + is_last=True, + is_interrupted=last_chunk.is_interrupted + if last_chunk + else False, + ) + + elif isinstance(res, ToolResponse): + final_result = res + + else: + raise TypeError( + "The tool function must return a ToolResponse " + "object, or an AsyncGenerator/Generator of " + "ToolResponse objects, " + f"but got {type(res)}.", + ) + + # Apply postprocess function if provided + if partial_postprocess_func: + from .._utils._common import _execute_async_or_sync_func + + processed_result = await _execute_async_or_sync_func( + partial_postprocess_func, + final_result, + ) + if processed_result: + final_result = processed_result + + except asyncio.CancelledError: + # Handle cancellation during execution + final_result = ToolResponse( + content=[ + TextBlock( + type="text", + text="" + "The tool call has been cancelled by the user." + "", + ), + ], + is_interrupted=True, + is_last=True, + ) + + except Exception as e: + # Handle any other errors during execution + final_result = ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Error during async execution: {e}", + ), + ], + ) + + finally: + # Store the result and remove from active tasks + self._async_results[task_id] = final_result + if task_id in self._async_tasks: + self._async_tasks.pop(task_id) + @trace_toolkit @_apply_middlewares async def call_tool_function( @@ -749,6 +928,45 @@ async def call_tool_function( else: partial_postprocess_func = None + # Check if async execution is enabled + if tool_func.async_execution: + # Generate a unique task ID + task_id = shortuuid.uuid() + + # Create and store the background task + task = asyncio.create_task( + self._execute_tool_in_background( + task_id=task_id, + tool_func=tool_func, + kwargs=kwargs, + partial_postprocess_func=partial_postprocess_func, + ), + ) + self._async_tasks[task_id] = task + + # Return a response with the task ID + return _object_wrapper( + ToolResponse( + content=[ + TextBlock( + type="text", + text=f"" + f"Tool '{tool_call['name']}' is executing " + f"asynchronously. " + f"Task ID: {task_id}. " + f"Use view_task('{task_id}') to check " + f"status, " + f"wait_task('{task_id}') to wait for " + f"completion, " + f"or cancel_task('{task_id}') to cancel " + f"the task." + f"", + ), + ], + ), + None, + ) + # Async function try: if inspect.iscoroutinefunction(tool_func.original_func): @@ -1314,3 +1532,148 @@ async def my_middleware( # Simply append the middleware to the list # The @apply_middlewares decorator will handle the execution self._middlewares.append(middleware) + + async def view_task(self, task_id: str) -> ToolResponse: + """View the status of an async tool task by its task ID. + + Args: + task_id (`str`): + The ID of the async tool task. + + Returns: + `ToolResponse`: + The tool response containing the status information of the + async task. + """ + if ( + task_id not in self._async_tasks + and task_id not in self._async_results + ): + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"InvalidTaskIdError: Cannot find async " + f"task with ID {task_id}.", + ), + ], + ) + + if task_id in self._async_tasks: + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Task {task_id} is still running.", + ), + ], + ) + + # If the task is completed, return the result or error + return self._async_results.pop(task_id) + + async def cancel_task(self, task_id: str) -> ToolResponse: + """Cancel an async tool task by its task ID. + + Args: + task_id (`str`): + The ID of the async tool task. + + Returns: + `ToolResponse`: + The tool response indicating whether the cancellation was + successful. + """ + if ( + task_id not in self._async_tasks + and task_id not in self._async_results + ): + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"InvalidTaskIdError: Cannot find async " + f"task with ID {task_id}.", + ), + ], + ) + + if task_id in self._async_results: + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Task {task_id} has already completed " + f"and cannot be cancelled.", + ), + ], + ) + + # Cancel the running task + task = self._async_tasks.pop(task_id) + task.cancel() + + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Task {task_id} has been cancelled.", + ), + ], + ) + + async def wait_task( + self, + task_id: str, + timeout: float = 10, + ) -> ToolResponse: + """Wait for an async tool execution to complete by its task ID. Note + the timeout shouldn't be too large, you can check the task status + by this tool every short period of time to avoid long waiting time. + + Args: + task_id (`str`): + The ID of the async tool task. + timeout (`float`, defaults to `10`): + The maximum time to wait for the task to complete, in seconds. + + Returns: + `ToolResponse`: + The tool response containing the result of the async task if + it completes within the timeout, or an error message if the + task is still running after the timeout. + """ + if ( + task_id not in self._async_tasks + and task_id not in self._async_results + ): + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"InvalidTaskIdError: Cannot find async " + f"task with ID {task_id}.", + ), + ], + ) + + if task_id in self._async_results: + return self._async_results.pop(task_id) + + # Wait for the running task to complete or timeout + task = self._async_tasks[task_id] + try: + await asyncio.wait_for(asyncio.shield(task), timeout=timeout) + except asyncio.TimeoutError: + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Task {task_id} is still running after " + f"waiting for {timeout} seconds.", + ), + ], + ) + + # If the task is completed, return the result or error + return self._async_results.pop(task_id) diff --git a/src/agentscope/tool/_types.py b/src/agentscope/tool/_types.py index 7d8a52d54b..3dbc8cfe22 100644 --- a/src/agentscope/tool/_types.py +++ b/src/agentscope/tool/_types.py @@ -54,6 +54,10 @@ class RegisteredToolFunction: returns `None`, the tool result will be returned as is. If it returns a `ToolResponse`, the returned block will be used as the final tool response.""" + async_execution: bool = False + """If this tool function is executed in an async manner, a reminder with + task id will be sent to the agent, allowing the agent to view, cancel or + check the status of the async task.""" @property def extended_json_schema(self) -> dict: diff --git a/src/agentscope/tuner/_config.py b/src/agentscope/tuner/_config.py index 917cb0ee45..a467b12602 100644 --- a/src/agentscope/tuner/_config.py +++ b/src/agentscope/tuner/_config.py @@ -48,7 +48,7 @@ def _to_trinity_config( "%Y%m%d%H%M%S", ) - _set_if_not_none(config, "monitor", monitor_type) + _set_if_not_none(config.monitor, "monitor_type", monitor_type) workflow_name = "agentscope_workflow_adapter_v1" if train_dataset is not None: diff --git a/tests/toolkit_async_execution_test.py b/tests/toolkit_async_execution_test.py new file mode 100644 index 0000000000..b66ae53d28 --- /dev/null +++ b/tests/toolkit_async_execution_test.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# pylint: disable=protected-access, redefined-builtin +"""Test async execution functionality in Toolkit.""" +import asyncio +from typing import AsyncGenerator +from unittest import IsolatedAsyncioTestCase + +from agentscope.message import ToolUseBlock, TextBlock +from agentscope.tool import ToolResponse, Toolkit + + +def _text(res: ToolResponse) -> str: + """Extract concatenated text from a ToolResponse.""" + return "".join( + block["text"] for block in res.content if block.get("type") == "text" + ) + + +async def slow_async_func(delay: float) -> ToolResponse: + """A slow async function for testing async execution. + + Args: + delay (`float`): + The time to sleep in seconds. + """ + await asyncio.sleep(delay) + return ToolResponse( + content=[TextBlock(type="text", text=f"done after {delay}s")], + ) + + +async def slow_async_generator_func( + delay: float, +) -> AsyncGenerator[ToolResponse, None]: + """A slow async generator function for testing async execution. + + Args: + delay (`float`): + The time to sleep in seconds. + """ + yield ToolResponse( + content=[TextBlock(type="text", text="chunk1")], + stream=True, + ) + await asyncio.sleep(delay) + yield ToolResponse( + content=[TextBlock(type="text", text="chunk1chunk2")], + stream=True, + is_last=True, + ) + + +class ToolkitAsyncExecutionTest(IsolatedAsyncioTestCase): + """Tests for async execution via call_tool_function.""" + + async def asyncSetUp(self) -> None: + self.toolkit = Toolkit() + + async def asyncTearDown(self) -> None: + self.toolkit = None + + def _make_tool_call(self, name: str, input: dict) -> ToolUseBlock: + return ToolUseBlock( + type="tool_use", + id="test-id", + name=name, + input=input, + ) + + async def _start_async_task(self, delay: float) -> str | None: + """Register slow_async_func, call it with async_execution=True, + and return the task_id from the response text.""" + self.toolkit.register_tool_function( + slow_async_func, + async_execution=True, + ) + res = await self.toolkit.call_tool_function( + self._make_tool_call("slow_async_func", {"delay": delay}), + ) + task_id: str | None = None + async for chunk in res: + # The response text contains "Task ID: " + for block in chunk.content: + if "Task ID:" in block["text"]: + task_id = ( + block["text"] + .split("Task ID:")[1] + .strip() + .split(".")[0] + .strip() + ) + self.assertIsNotNone( + task_id, + "task_id should be present in response", + ) + return task_id + + # ------------------------------------------------------------------ + # 1. view_task: still running vs completed + # ------------------------------------------------------------------ + + async def test_view_task_still_running(self) -> None: + """view_task returns 'still running' when task is not yet done.""" + task_id = await self._start_async_task(delay=5.0) + + # Task should still be running immediately after launch + res = await self.toolkit.view_task(task_id) + self.assertIn(task_id, _text(res)) + self.assertIn("still running", _text(res)) + + # Clean up: cancel the task so it doesn't linger + await self.toolkit.cancel_task(task_id) + + async def test_view_task_completed(self) -> None: + """view_task returns the result once the task has finished.""" + task_id = await self._start_async_task(delay=0.05) + + # Wait for the task to finish + await asyncio.sleep(0.2) + + res = await self.toolkit.view_task(task_id) + self.assertIn("done after 0.05s", _text(res)) + + # Result should have been consumed; task_id no longer tracked + self.assertNotIn(task_id, self.toolkit._async_results) + self.assertNotIn(task_id, self.toolkit._async_tasks) + + # ------------------------------------------------------------------ + # 2. wait_task: completes within timeout vs times out + # ------------------------------------------------------------------ + + async def test_wait_task_completes_within_timeout(self) -> None: + """wait_task returns the result when task finishes before timeout.""" + task_id = await self._start_async_task(delay=0.05) + + res = await self.toolkit.wait_task(task_id, timeout=5.0) + self.assertIn("done after 0.05s", _text(res)) + + async def test_wait_task_timeout(self) -> None: + """wait_task returns a timeout message when task exceeds timeout.""" + task_id = await self._start_async_task(delay=5.0) + + res = await self.toolkit.wait_task(task_id, timeout=0.05) + self.assertIn("still running", _text(res)) + + # Task should still be alive after timeout + self.assertIn(task_id, self.toolkit._async_tasks) + + # Clean up + await self.toolkit.cancel_task(task_id) + + # ------------------------------------------------------------------ + # 3. cancel_task + # ------------------------------------------------------------------ + + async def test_cancel_task(self) -> None: + """cancel_task cancels a running task successfully.""" + task_id = await self._start_async_task(delay=5.0) + + res = await self.toolkit.cancel_task(task_id) + self.assertIn("cancelled", _text(res).lower()) + + # Task should be removed from active tasks + self.assertNotIn(task_id, self.toolkit._async_tasks) + + async def test_cancel_already_completed_task(self) -> None: + """cancel_task on a completed task returns an appropriate message.""" + task_id = await self._start_async_task(delay=0.05) + + # Wait for completion + await asyncio.sleep(0.2) + + res = await self.toolkit.cancel_task(task_id) + self.assertIn("already completed", _text(res).lower()) + + # ------------------------------------------------------------------ + # 4. Streaming tool function: chunks are accumulated into one result + # ------------------------------------------------------------------ + + async def test_async_generator_result_is_accumulated(self) -> None: + """Streaming tool results are accumulated into a single + ToolResponse.""" + self.toolkit.register_tool_function( + slow_async_generator_func, + async_execution=True, + ) + res = await self.toolkit.call_tool_function( + self._make_tool_call( + "slow_async_generator_func", + {"delay": 0.05}, + ), + ) + task_id = None + async for chunk in res: + for block in chunk.content: + if "Task ID:" in block["text"]: + task_id = ( + block["text"] + .split("Task ID:")[1] + .strip() + .split(".")[0] + .strip() + ) + self.assertIsNotNone(task_id) + + # Wait for the generator to finish + result = await self.toolkit.wait_task(task_id, timeout=5.0) + + # Both chunks' content should be present in the accumulated result + text = _text(result) + self.assertIn("chunk1", text) + self.assertIn("chunk2", text) + + # ------------------------------------------------------------------ + # 5. Invalid task_id + # ------------------------------------------------------------------ + + async def test_view_invalid_task_id(self) -> None: + """view_task with unknown task_id returns an error message.""" + res = await self.toolkit.view_task("nonexistent-id") + self.assertIn("InvalidTaskIdError", _text(res)) + + async def test_cancel_invalid_task_id(self) -> None: + """cancel_task with unknown task_id returns an error message.""" + res = await self.toolkit.cancel_task("nonexistent-id") + self.assertIn("InvalidTaskIdError", _text(res)) + + async def test_wait_invalid_task_id(self) -> None: + """wait_task with unknown task_id returns an error message.""" + res = await self.toolkit.wait_task("nonexistent-id", timeout=1.0) + self.assertIn("InvalidTaskIdError", _text(res))