Skip to content

Simplify RoleZero code #1786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion metagpt/roles/di/engineer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser, awrite
from metagpt.utils.report import EditorReporter
from metagpt.utils.role_zero_utils import get_plan_status


@register_tool(include_functions=["write_new_code"])
Expand Down Expand Up @@ -117,7 +118,7 @@ async def write_new_code(self, path: str, file_description: str = "") -> str:
"""
# If the path is not absolute, try to fix it with the editor's working directory.
path = self.editor._try_fix_path(path)
plan_status, _ = self._get_plan_status()
plan_status, _ = get_plan_status(planner=self.planner)
prompt = WRITE_CODE_PROMPT.format(
user_requirement=self.planner.plan.goal,
plan_status=plan_status,
Expand Down
194 changes: 29 additions & 165 deletions metagpt/roles/di/role_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,28 @@
import re
import traceback
from datetime import datetime
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
from typing import Annotated, Callable, Literal, Optional, Tuple

from pydantic import Field, model_validator

from metagpt.actions import Action, UserRequirement
from metagpt.actions.di.run_command import RunCommand
from metagpt.actions.search_enhanced_qa import SearchEnhancedQA
from metagpt.const import IMAGES
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers import RoleZeroSerializer
from metagpt.logs import logger
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.prompts.di.role_zero import (
ASK_HUMAN_COMMAND,
ASK_HUMAN_GUIDANCE_FORMAT,
CMD_PROMPT,
DETECT_LANGUAGE_PROMPT,
END_COMMAND,
JSON_REPAIR_PROMPT,
QUICK_RESPONSE_SYSTEM_PROMPT,
QUICK_THINK_EXAMPLES,
QUICK_THINK_PROMPT,
QUICK_THINK_SYSTEM_PROMPT,
QUICK_THINK_TAG,
REGENERATE_PROMPT,
REPORT_TO_HUMAN_PROMPT,
ROLE_INSTRUCTION,
SUMMARY_PROBLEM_WHEN_DUPLICATE,
SUMMARY_PROMPT,
SYSTEM_PROMPT,
)
Expand All @@ -45,13 +38,17 @@
from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
from metagpt.utils.repair_llm_raw_output import (
RepairType,
repair_escape_error,
repair_llm_raw_output,
)
from metagpt.utils.common import any_to_str
from metagpt.utils.report import ThoughtReporter
from metagpt.utils.role_zero_utils import (
check_duplicates,
format_terminal_output,
get_plan_status,
parse_browser_actions,
parse_commands,
parse_editor_result,
parse_images,
)


@register_tool(include_functions=["ask_human", "reply_to_human"])
Expand Down Expand Up @@ -216,7 +213,7 @@ async def _think(self) -> bool:
example = self._retrieve_experience()

### 2. Plan Status ###
plan_status, current_task = self._get_plan_status()
plan_status, current_task = get_plan_status(planner=self.planner)

### 3. Tool/Command Info ###
tools = await self.tool_recommender.recommend_tools()
Expand All @@ -242,9 +239,9 @@ async def _think(self) -> bool:

### Recent Observation ###
memory = self.rc.memory.get(self.memory_k)
memory = await self.parse_browser_actions(memory)
memory = await self.parse_editor_result(memory)
memory = self.parse_images(memory)
memory = await parse_browser_actions(memory, browser=self.browser)
memory = await parse_editor_result(memory)
memory = await parse_images(memory, llm=self.llm)

req = self.llm.format_msg(memory + [UserMessage(content=prompt)])
state_data = dict(
Expand All @@ -255,7 +252,16 @@ async def _think(self) -> bool:
async with ThoughtReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "react"})
self.command_rsp = await self.llm_cached_aask(req=req, system_msgs=[system_prompt], state_data=state_data)
self.command_rsp = await self._check_duplicates(req, self.command_rsp)

rsp_hist = [mem.content for mem in self.rc.memory.get()]
self.command_rsp = await check_duplicates(
req=req,
command_rsp=self.command_rsp,
rsp_hist=rsp_hist,
llm=self.llm,
respond_language=self.respond_language,
)

return True

@exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer())
Expand All @@ -267,44 +273,6 @@ async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str], **kw
"""
return await self.llm.aask(req, system_msgs=system_msgs)

async def parse_browser_actions(self, memory: list[Message]) -> list[Message]:
    """Attach the current browser view after the latest browser command.

    When the browser is not on an empty page, the messages are scanned from
    newest to oldest for content matching ``Command Browser.<name> executed``.
    A user message (``cause_by="browser"``) holding the result of
    ``self.browser.view()`` is inserted directly after the first match found.

    Args:
        memory: Messages to scan; the list is modified in place.

    Returns:
        The same ``memory`` list, with at most one message inserted.
    """
    if self.browser.is_empty_page:
        return memory

    browser_cmd = re.compile(r"Command Browser\.(\w+) executed")
    # Walk backwards so only the most recent browser command gets a view.
    for position in range(len(memory) - 1, -1, -1):
        if not browser_cmd.search(memory[position].content):
            continue
        memory.insert(position + 1, UserMessage(cause_by="browser", content=await self.browser.view()))
        break
    return memory

async def parse_editor_result(self, memory: list[Message], keep_latest_count=5) -> list[Message]:
    """Keep recent editor results in full and shorten the older ones.

    A message counts as an editor result when its content matches
    ``Command Editor.<name> executed``. Counting from the newest message,
    the first ``keep_latest_count`` editor results are kept untouched. Each
    older one is replaced by a ``UserMessage`` containing the text before the
    first ``Command Editor`` occurrence, followed by one
    ``Command Editor.<name> executed.`` line per matched command.

    Args:
        memory: Messages in chronological order; the list itself is not modified.
        keep_latest_count: Number of most recent editor results kept verbatim.

    Returns:
        A new list of messages in the same chronological order.
    """
    editor_cmd = re.compile(r"Command Editor\.(\w+?) executed")
    editor_results_seen = 0
    newest_first = []
    for position in range(len(memory) - 1, -1, -1):
        entry = memory[position]
        executed = editor_cmd.findall(entry.content)
        if executed:
            editor_results_seen += 1
            if editor_results_seen > keep_latest_count:
                # Drop the bulky command output, keep only which commands ran.
                head = entry.content[: entry.content.find("Command Editor")]
                summary = "\n".join([f"Command Editor.{name} executed." for name in executed])
                entry = UserMessage(content=head + summary)
        newest_first.append(entry)
    # Restore chronological order so the latest message is at the end.
    return newest_first[::-1]

def parse_images(self, memory: list[Message]) -> list[Message]:
    """Attach encoded images to user messages when the LLM accepts image input.

    If ``self.llm.support_image_input()`` is false the messages are returned
    untouched. Otherwise every message with role ``"user"`` that has no
    ``IMAGES`` entry in its metadata is passed to
    ``extract_and_encode_images``; a non-empty result is stored under the
    ``IMAGES`` metadata key.

    Args:
        memory: Messages to inspect; their metadata may be updated in place.

    Returns:
        The same ``memory`` list.
    """
    if self.llm.support_image_input():
        for message in memory:
            if IMAGES not in message.metadata and message.role == "user":
                encoded = extract_and_encode_images(message.content)
                if encoded:
                    message.add_metadata(IMAGES, encoded)
    return memory

def _get_prefix(self) -> str:
    """Return the parent class's prefix with the current local time appended."""
    time_suffix = f" The current time is {datetime.now():%Y-%m-%d %H:%M:%S}."
    return super()._get_prefix() + time_suffix
Expand All @@ -313,7 +281,9 @@ async def _act(self) -> Message:
if self.use_fixed_sop:
return await super()._act()

commands, ok, self.command_rsp = await self._parse_commands(self.command_rsp)
commands, ok, self.command_rsp = await parse_commands(
command_rsp=self.command_rsp, llm=self.llm, exclusive_tool_commands=self.exclusive_tool_commands
)
self.rc.memory.add(AIMessage(content=self.command_rsp))
if not ok:
error_msg = commands
Expand Down Expand Up @@ -412,85 +382,6 @@ async def _quick_think(self) -> Tuple[Message, str]:

return rsp_msg, intent_result

async def _check_duplicates(self, req: list[dict], command_rsp: str, check_window: int = 10):
    """Detect a repeated LLM response and replace it with a recovery response.

    A normal response containing thought content is highly unlikely to be
    reproduced verbatim, so an identical response within the recent memory
    window is treated as a bad response (mostly the LLM repeating itself).

    Args:
        req: The request messages that produced ``command_rsp``.
        command_rsp: The raw response just returned by the LLM.
        check_window: How many recent memory entries to compare against.

    Returns:
        - ``command_rsp`` unchanged when it is not a duplicate, or when it
          contains the ``end`` command.
        - ``END_COMMAND`` when a ``Plan.finish_current_task`` response has
          appeared at least three times in the window.
        - A JSON code block asking the human for guidance when any other
          response has appeared at least three times.
        - Otherwise, a freshly regenerated LLM response.
    """
    import copy  # Local import: only needed on the ask-human path.

    past_rsp = [mem.content for mem in self.rc.memory.get(check_window)]
    if command_rsp not in past_rsp or '"command_name": "end"' in command_rsp:
        return command_rsp

    # TODO: switch to llm_cached_aask

    # Hard rule: after three identical responses, stop retrying by ourselves.
    if past_rsp.count(command_rsp) >= 3:
        if '"command_name": "Plan.finish_current_task",' in command_rsp:
            # Repeating 'Plan.finish_current_task' means the work is done; finish with the 'end' command.
            logger.warning(f"Duplicate response detected: {command_rsp}")
            return END_COMMAND
        problem = await self.llm.aask(
            req + [UserMessage(content=SUMMARY_PROBLEM_WHEN_DUPLICATE.format(language=self.respond_language))]
        )
        # Work on a copy so the shared module-level template is never mutated;
        # otherwise one call's question would leak into later calls and other roles.
        ask_human = copy.deepcopy(ASK_HUMAN_COMMAND)
        ask_human[0]["args"]["question"] = ASK_HUMAN_GUIDANCE_FORMAT.format(problem=problem).strip()
        return "```json\n" + json.dumps(ask_human, indent=4, ensure_ascii=False) + "\n```"

    # Fewer than three repeats: ask the LLM to correct itself.
    logger.warning(f"Duplicate response detected: {command_rsp}")
    regenerate_req = self.llm.format_msg(req + [UserMessage(content=REGENERATE_PROMPT)])
    return await self.llm.aask(regenerate_req)

async def _parse_commands(self, command_rsp) -> Tuple[List[Dict], bool, str]:
    """Parse the LLM response text into a list of command dicts.

    Parsing is attempted in stages:
      1. Extract the JSON code block, prepend a missing ``[`` and run the
         generic JSON repair before decoding.
      2. On a JSON decode error, ask the LLM to repair the JSON and decode
         its answer.
      3. If that still fails to decode, repair escape errors in the original
         block and decode once more.

    If more than one exclusive tool command is present, the command list is
    cut right after the first exclusive command and ``command_rsp`` is
    rewritten to match the truncated list.

    Args:
        command_rsp: Raw response text from the LLM.

    Returns:
        A 3-tuple ``(commands, ok, command_rsp)``:
          - ``commands``: the parsed command list on success, or the error
            message string on failure.
          - ``ok``: True on success, False if parsing failed.
          - ``command_rsp``: the (possibly rewritten) response text.
    """
    try:
        commands = CodeParser.parse_code(block=None, lang="json", text=command_rsp)
        if commands.endswith("]") and not commands.startswith("["):
            commands = "[" + commands
        commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse JSON for: {command_rsp}. Trying to repair...")
        commands = await self.llm.aask(
            msg=JSON_REPAIR_PROMPT.format(json_data=command_rsp, json_decode_error=str(e))
        )
        # An exception raised inside this handler is NOT caught by the sibling
        # `except Exception` below, so the repair parsing needs its own guard.
        try:
            try:
                commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
            except json.JSONDecodeError:
                # Repair escape errors of code and math in the original response.
                commands = CodeParser.parse_code(block=None, lang="json", text=command_rsp)
                new_command = repair_escape_error(commands)
                commands = json.loads(
                    repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON)
                )
        except Exception as repair_error:
            logger.error(traceback.format_exc())
            return str(repair_error), False, command_rsp
    except Exception as e:
        logger.error(traceback.format_exc())
        return str(e), False, command_rsp

    # Tolerate LLM output that does not follow the expected format:
    # a single command dict, or a dict wrapping the list under "commands".
    if isinstance(commands, dict):
        commands = commands["commands"] if "commands" in commands else [commands]

    # A flag is False for every command that is an exclusive tool command.
    command_flag = [command["command_name"] not in self.exclusive_tool_commands for command in commands]
    if command_flag.count(False) > 1:
        # Keep only the commands up to and including the first exclusive command.
        index_of_first_exclusive = command_flag.index(False)
        commands = commands[: index_of_first_exclusive + 1]
        command_rsp = "```json\n" + json.dumps(commands, indent=4, ensure_ascii=False) + "\n```"
        logger.info(
            "exclusive command more than one in current command list. change the command list.\n" + command_rsp
        )
    return commands, True, command_rsp

async def _run_commands(self, commands) -> str:
outputs = []
for cmd in commands:
Expand Down Expand Up @@ -552,36 +443,9 @@ async def _run_special_command(self, cmd) -> str:
elif cmd["command_name"] == "Terminal.run_command":
tool_obj = self.tool_execution_map[cmd["command_name"]]
tool_output = await tool_obj(**cmd["args"])
if len(tool_output) <= 10:
command_output += (
f"\n[command]: {cmd['args']['cmd']} \n[command output] : {tool_output} (pay attention to this.)"
)
else:
command_output += f"\n[command]: {cmd['args']['cmd']} \n[command output] : {tool_output}"

command_output = format_terminal_output(cmd=cmd, raw_output=tool_output)
return command_output

def _get_plan_status(self) -> Tuple[str, str]:
    """Render the planner's plan as text and return it with the current task.

    Returns:
        A pair ``(formatted_plan_status, current_task)``. The first item is
        the goal line followed by one line per task (or a "No Plan" line).
        The second is the current task dumped without its ``code``,
        ``result`` and ``is_success`` fields, or ``""`` when there is no
        current task.

    Example of the formatted text:
        [GOAL] create a 2048 game
        [TASK_ID 1] (finished) Create a Product Requirement Document (PRD) for the 2048 game. This task depends on tasks[]. [Assign to Alice]
        [TASK_ID 2] ( ) Design the system architecture for the 2048 game. This task depends on tasks[1]. [Assign to Bob]
    """
    plan = self.planner.plan
    snapshot = plan.model_dump(include=["goal", "tasks"])

    active_task = plan.current_task
    if active_task:
        current_task = active_task.model_dump(exclude=["code", "result", "is_success"])
    else:
        current_task = ""

    pieces = [f"[GOAL] {snapshot['goal']}\n"]
    if len(snapshot["tasks"]) == 0:
        pieces.append("No Plan \n")
    else:
        pieces.append("[Plan]\n")
        for task in snapshot["tasks"]:
            state = "finished" if task["is_finished"] else " "
            pieces.append(
                f"[TASK_ID {task['task_id']}] ({state}){task['instruction']} This task depends on tasks{task['dependent_task_ids']}. [Assign to {task['assignee']}]\n"
            )
    return "".join(pieces), current_task

def _retrieve_experience(self) -> str:
"""Default implementation of experience retrieval. Can be overwritten in subclasses."""
context = [str(msg) for msg in self.rc.memory.get(self.memory_k)]
Expand Down
Loading
Loading