Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions trae_agent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def run(
extra_args: dict[str, str] | None = None,
tool_names: list[str] | None = None,
):
await self.agent.reset()
self.agent.new_task(task, extra_args, tool_names)

if self.agent.allow_mcp_servers:
Expand Down
16 changes: 16 additions & 0 deletions trae_agent/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(
tools_registry[tool_name](model_provider=self._model_config.model_provider.provider)
for tool_name in agent_config.tools
]
# Keep a copy of the base tools configuration
self._base_tools: list[Tool] = list(self._tools)

self.docker_keep = docker_keep
self.docker_manager: DockerManager | None = None
original_tool_executor = ToolExecutor(self._tools)
Expand Down Expand Up @@ -282,6 +285,19 @@ async def cleanup_mcp_clients(self) -> None:
"""Clean up MCP clients. Override in subclasses that use MCP."""
pass

async def reset(self) -> None:
    """Return the agent to a clean pre-task state.

    Resets every registered tool, wipes the LLM conversation history, and
    restarts the Docker shell (when one is active) so that environment
    variables and the working directory do not leak between tasks.
    """
    # Tools may hold per-task state (sessions, histories); reset each one.
    for current_tool in self._tools:
        await current_tool.reset()

    # Drop the accumulated conversation so the next task starts fresh.
    self._llm_client.clear_history()

    # A live Docker shell retains env vars and cwd; restart it to clear them.
    if self.docker_manager:
        self.docker_manager.restart_shell()

def _update_cli_console(
self, step: AgentStep | None = None, agent_execution: AgentExecution | None = None
) -> None:
Expand Down
9 changes: 9 additions & 0 deletions trae_agent/agent/docker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ def stop(self):

self.container = None

def restart_shell(self):
    """Restarts the persistent shell to clear session state (env vars, cwd).

    Force-closes any live shell process, drops the reference, and spawns a
    fresh shell only while the container is still running.
    """
    # Close the old shell if it is still alive; force=True avoids hanging
    # on an unresponsive process.
    if self.shell and self.shell.isalive():
        self.shell.close(force=True)
    self.shell = None
    # Only start if container is active
    if self.container:
        self._start_persistent_shell()

# --- Private Helper Methods ---

def _copy_tools_to_container(self):
Expand Down
15 changes: 15 additions & 0 deletions trae_agent/agent/trae_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,18 @@ async def cleanup_mcp_clients(self) -> None:
# Use a generic server name for cleanup since we don't track which server each client is for
await client.cleanup("cleanup")
self.mcp_clients.clear()

@override
async def reset(self) -> None:
    """Reset the TraeAgent state."""
    # Tear down MCP client sessions left over from the previous task.
    await self.cleanup_mcp_clients()

    # Restore the tool list to the original base set, which also discards
    # any MCP tools accumulated during the last run...
    self._tools = list(self._base_tools)

    # ...and forget the MCP tool records themselves.
    self.mcp_tools = []

    # Base-class reset handles tool state, LLM history, and Docker shell.
    await super().reset()
4 changes: 4 additions & 0 deletions trae_agent/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ async def close(self):
"""Ensure proper tool resource deallocation before task completion."""
return None # Using "pass" will trigger a Ruff check error: B027

async def reset(self):
    """Reset the tool state. Override this method if the tool maintains state across tasks."""
    # Default is a no-op. "return None" instead of "pass" avoids the Ruff
    # B027 warning on empty non-abstract methods in an abstract base class
    # (same convention as close() above).
    return None


class ToolExecutor:
"""Tool executor that manages tool execution."""
Expand Down
5 changes: 5 additions & 0 deletions trae_agent/tools/bash_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,8 @@ async def close(self):
ret = await self._session.stop()
self._session = None
return ret

@override
async def reset(self):
    """Reset the tool state.

    Closing tears down the persistent bash session; presumably a fresh
    session is created on next use — confirm against execute().
    """
    await self.close()
6 changes: 6 additions & 0 deletions trae_agent/tools/sequential_thinking_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,9 @@ async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
error=f"Sequential thinking failed: {str(e)}\n\nDetails:\n{json.dumps(error_data, indent=2)}",
error_code=-1,
)

@override
async def reset(self) -> None:
"""Reset the tool state."""
self.thought_history = []
self.branches = {}
6 changes: 6 additions & 0 deletions trae_agent/utils/llm_clients/anthropic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def set_chat_history(self, messages: list[LLMMessage]) -> None:
"""Set the chat history."""
self.message_history = self.parse_messages(messages)

@override
def clear_history(self) -> None:
    """Forget the stored conversation and the system prompt."""
    # NOT_GIVEN (not None) is the Anthropic SDK's sentinel for "unset".
    self.message_history, self.system_message = [], anthropic.NOT_GIVEN

def _create_anthropic_response(
self,
model_config: ModelConfig,
Expand Down
5 changes: 5 additions & 0 deletions trae_agent/utils/llm_clients/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def set_chat_history(self, messages: list[LLMMessage]) -> None:
"""Set the chat history."""
pass

@abstractmethod
def clear_history(self) -> None:
    """Clear the chat history.

    Concrete clients must drop any stored conversation state (the message
    history and, where the provider keeps one, the cached system message).
    """
    pass

@abstractmethod
def chat(
self,
Expand Down
6 changes: 6 additions & 0 deletions trae_agent/utils/llm_clients/google_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def set_chat_history(self, messages: list[LLMMessage]) -> None:
"""Set the chat history."""
self.message_history, self.system_instruction = self.parse_messages(messages)

@override
def clear_history(self) -> None:
"""Clear the chat history."""
self.message_history = []
self.system_instruction = None

def _create_google_response(
self,
model_config: ModelConfig,
Expand Down
4 changes: 4 additions & 0 deletions trae_agent/utils/llm_clients/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def set_chat_history(self, messages: list[LLMMessage]) -> None:
"""Set the chat history."""
self.client.set_chat_history(messages)

def clear_history(self) -> None:
    """Clear the chat history by delegating to the provider-specific client."""
    self.client.clear_history()

def chat(
self,
messages: list[LLMMessage],
Expand Down
5 changes: 5 additions & 0 deletions trae_agent/utils/llm_clients/ollama_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def __init__(self, model_config: ModelConfig):
def set_chat_history(self, messages: list[LLMMessage]) -> None:
self.message_history = self.parse_messages(messages)

@override
def clear_history(self) -> None:
"""Clear the chat history."""
self.message_history = []

def _create_ollama_response(
self,
model_config: ModelConfig,
Expand Down
5 changes: 5 additions & 0 deletions trae_agent/utils/llm_clients/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def set_chat_history(self, messages: list[LLMMessage]) -> None:
"""Set the chat history."""
self.message_history = self.parse_messages(messages)

@override
def clear_history(self) -> None:
"""Clear the chat history."""
self.message_history = []

def _create_openai_response(
self,
api_call_input: ResponseInputParam,
Expand Down
5 changes: 5 additions & 0 deletions trae_agent/utils/llm_clients/openai_compatible_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def set_chat_history(self, messages: list[LLMMessage]) -> None:
"""Set the chat history."""
self.message_history = self.parse_messages(messages)

@override
def clear_history(self) -> None:
"""Clear the chat history."""
self.message_history = []

def _create_response(
self,
model_config: ModelConfig,
Expand Down