Skip to content

Support response format & llmstudio #1704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions metagpt/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SerializationMixin,
TestingContext,
)
from metagpt.utils.format import ResponseFormat
from metagpt.utils.project_repo import ProjectRepo


Expand Down Expand Up @@ -103,9 +104,9 @@ def __str__(self):
def __repr__(self):
return self.__str__()

async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str:
async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None, response_format: Optional[ResponseFormat] = None) -> str:
"""Append default prefix"""
return await self.llm.aask(prompt, system_msgs)
return await self.llm.aask(prompt, system_msgs, response_format=response_format)

async def _run_action_node(self, *args, **kwargs):
"""Run action node"""
Expand Down
1 change: 1 addition & 0 deletions metagpt/configs/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class LLMType(Enum):
OPENROUTER = "openrouter"
BEDROCK = "bedrock"
ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk
LLMSTUDIO = "llmstudio"

def __missing__(self, key):
return self.OPENAI
Expand Down
113 changes: 105 additions & 8 deletions metagpt/ext/werewolf/actions/common_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,50 @@
# @Desc :

import json
import re

from tenacity import retry, stop_after_attempt, wait_fixed

from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.utils.common import parse_json_code_block

from metagpt.utils.format import ResponseFormat, JsonResponseFormat

def log_and_parse_json(name: str, rsp: str) -> dict:
    """Log an LLM reply and parse the first JSON code block it contains.

    Args:
        name: Action name, used only for log context.
        rsp: Raw LLM reply text.

    Returns:
        The parsed JSON object.

    Raises:
        json.JSONDecodeError: if neither the extracted code block nor the
            ``convert_to_json_if_reason`` fallback yields valid JSON.
    """
    rsp = rsp.replace("\n", " ")
    logger.debug(f"{name} result: {rsp}")
    json_blocks = parse_json_code_block(rsp)
    # Parse inside try/except so reasoning-model output (<think> tags etc.)
    # can be salvaged by the fallback converter instead of failing outright.
    try:
        rsp_json = json.loads(json_blocks[0])
    except json.JSONDecodeError as e:
        logger.error(f"Error parsing JSON: {rsp}, {e}")
        rsp_json = json.loads(convert_to_json_if_reason(rsp))
    return rsp_json


def convert_to_json_if_reason(rsp: str) -> str:
    """Coerce a reasoning-style reply into a JSON string.

    Handles deepseek-style output of the form ``<think>...</think> **RESPONSE:** {answer}``.
    Payloads are built with ``json.dumps`` so quotes, newlines, and backslashes
    in the free-form text cannot produce invalid JSON (the previous f-string
    construction broke on such characters).

    Args:
        rsp: Raw model reply.

    Returns:
        A string guaranteed to be valid JSON, with a "RESPONSE" key and,
        when a closed <think> section is present, a "THOUGHT" key.
    """
    answer_marker = r"\*?\*?(RESPONSE|ANSWER)\*?\*?:"
    if re.search(r"<think\w*>", rsp, re.IGNORECASE) and re.search(answer_marker, rsp, re.IGNORECASE):
        think = re.search(r"(?:<think\w*>)(.*?)(?:</think\w*>)", rsp, re.DOTALL)
        if think:
            # Take the last whitespace-separated token after the final
            # RESPONSE/ANSWER marker; guard against trailing whitespace.
            answer = re.split(answer_marker, rsp, flags=re.IGNORECASE)[-1]
            tokens = answer.split()
            response = tokens[-1] if tokens else ""
            return json.dumps({"THOUGHT": think.group(1), "RESPONSE": response})
        # No closed <think> section: fall back to the last brace-delimited
        # fragment, assuming it is the intended JSON object.
        elif re.search(r"\{(.*?)\}", rsp, re.DOTALL):
            return "{" + re.findall(r"\{(.*?)\}", rsp, re.DOTALL)[-1] + "}"
    # Find the last occurrence of a "PlayerX" / "pass" / "none" token and
    # treat it as the RESPONSE.
    found = re.findall(r"(Player\s*\d+|pass|none)", rsp, re.IGNORECASE)
    if found:
        return json.dumps({"RESPONSE": found[-1]})
    else:
        # Strip any <think>...</think> section and wrap the remainder.
        rsp = re.sub(r"(?:<think\w*>)(.*?)(?:</think\w*>)", "", rsp)
        return json.dumps({"RESPONSE": rsp})


class Speak(Action):
"""Action: Any speak action in a game"""

Expand Down Expand Up @@ -53,6 +81,17 @@ class Speak(Action):

name: str = "Speak"

def _construct_response_format(self, **kwargs) -> ResponseFormat:
    """Build the JSON response schema for the Speak action.

    All properties are required; placeholder tokens (``__profile__`` etc.)
    in the descriptions are filled in from ``kwargs``.
    """
    spec = [
        ("ROLE", "Your role, in this case, __profile__", "string"),
        ("PLAYER_NAME", "Your name, in this case, __name__", "string"),
        ("LIVING_PLAYERS", "List the players who is alive based on moderator's latest instruction. Return a json LIST datatype.", "array"),
        ("THOUGHTS", "Choose one living player from `LIVING_PLAYERS` to __action__ this night. Return the reason why you choose to __action__ this player. If you observe nothing at first night, DONT imagine unexisting player actions! If you find similar situation in `PAST_EXPERIENCES`, you may draw lessons from them to refine your strategy and take better actions. Give your step-by-step thought process, you should think no more than 3 steps. For example: My step-by-step thought process:...", "string"),
        ("RESPONSE", "As a __profile__, you should choose one living player from `LIVING_PLAYERS` to __action__ this night according to the THOUGHTS you have just now. Return the player name ONLY.", "string"),
    ]
    fmt = JsonResponseFormat()
    for key, describe, dtype in spec:
        fmt.add_property(key, describe, dtype, required=True)
    fmt.update_describe_with_dict(kwargs)
    return fmt


@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
async def run(
self,
Expand All @@ -73,10 +112,22 @@ async def run(
.replace("__experiences__", experiences)
)

rsp = await self._aask(prompt)
response_formatter = self._construct_response_format(
__profile__=profile, __name__=name, __context__=context, __latest_instruction__=latest_instruction,
__strategy__=self.STRATEGY, __reflection__=reflection, __experiences__=experiences
)
rsp = await self._aask(prompt, response_format=response_formatter)
rsp_json = log_and_parse_json(self.name, rsp)

return rsp_json["RESPONSE"]
try:
return rsp_json["RESPONSE"]
except KeyError as e:
# For some not-so-smart models, there will be no RESPONSE key, so here we look for the key value corresponding to the most similar key
ret = response_formatter.get_most_likely_key(rsp_json, "RESPONSE")
if ret:
return ret
else:
raise e


class NighttimeWhispers(Action):
Expand Down Expand Up @@ -183,13 +234,27 @@ def _update_prompt_json(
# one can modify the prompt_json dictionary here
return prompt_json

def _construct_response_format(self, **kwargs) -> ResponseFormat:
    """Assemble the JSON response schema for nighttime whisper actions.

    Every field is marked required; ``kwargs`` supplies values for the
    ``__...__`` placeholders embedded in the descriptions.
    """
    properties = {
        "ROLE": ("Your role, in this case, __profile__", "string"),
        "PLAYER_NAME": ("Your name, in this case, __name__", "string"),
        "LIVING_PLAYERS": ("List the players who is alive based on moderator's latest instruction. Return a json LIST datatype.", "array"),
        "THOUGHTS": ("Choose one living player from `LIVING_PLAYERS` to __action__ this night. Return the reason why you choose to __action__ this player. If you observe nothing at first night, DONT imagine unexisting player actions! If you find similar situation in `PAST_EXPERIENCES`, you may draw lessons from them to refine your strategy and take better actions. Give your step-by-step thought process, you should think no more than 3 steps. For example: My step-by-step thought process:...", "string"),
        "RESPONSE": ("As a __profile__, you should choose one living player from `LIVING_PLAYERS` to __action__ this night according to the THOUGHTS you have just now. Return the player name ONLY.", "string"),
    }
    response_format = JsonResponseFormat()
    for field_name, (description, field_type) in properties.items():
        response_format.add_property(field_name, description, field_type, required=True)
    response_format.update_describe_with_dict(kwargs)
    return response_format


@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
async def run(self, context: str, profile: str, name: str, reflection: str = "", experiences: str = ""):
prompt = self._construct_prompt_json(
role_profile=profile, role_name=name, context=context, reflection=reflection, experiences=experiences
)

rsp = await self._aask(prompt)
rsp = await self._aask(prompt, response_format=self._construct_response_format(
__profile__=profile, __name__=name, __context__=context, __reflection__=reflection,
__experiences__=experiences, __action__=self.name, __strategy_=self.STRATEGY
))
rsp_json = log_and_parse_json(self.name, rsp)

return f"{self.name} " + rsp_json["RESPONSE"]
Expand Down Expand Up @@ -225,7 +290,28 @@ class Reflect(Action):

name: str = "Reflect"

@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
def _construct_response_format(self, **kwargs) -> ResponseFormat:
    """Build the JSON response schema for the Reflect action.

    All properties are required; ``kwargs`` fills in the ``__...__``
    placeholder tokens embedded in the descriptions via
    ``update_describe_with_dict``.
    """
    response_format = JsonResponseFormat()
    response_format.add_property("ROLE", "Your role, in this case, __profile__", "string", required=True)
    response_format.add_property("PLAYER_NAME", "Your name, in this case, __name__", "string", required=True)
    response_format.add_property("LIVING_PLAYERS", "List the players who is alive based on moderator's latest instruction. Return a json LIST datatype.", "array", required=True)
    # GAME_STATES: one-line JSON analysis per player, returned as a JSON list.
    response_format.add_property("GAME_STATES", """You are about to follow `MODERATOR_INSTRUCTION`, but before taking any action, analyze each player, including the living and the dead, and summarize the game states.
For each player, your reflection should be a ONE-LINE json covering the following dimension, return a LIST of jsons (return an empty LIST for the first night):
[
{"TARGET": "the player you will analyze, if the player is yourself or your werewolf partner, indicate it" ,"STATUS": "living or dead, if dead, how was he/she possibly killed?", "CLAIMED_ROLE": "claims a role or not, if so, what role, any contradiction to others? If there is no claim, return 'None'", "SIDE_WITH": "sides with which players? If none, return 'None'", "ACCUSE": "accuses which players? If none, return 'None'"}
,{...}
,...
]""", "array", required=True)
    # REFLECTION: per-player role inference plus a one-sentence game summary.
    response_format.add_property("REFLECTION", """Based on the whole `GAME_STATES`, return a json (return an empty string for the first night):
{
"Player1": "the true role (werewolf / special role / villager, living or dead) you infer about him/her, and why is this role? If the player is yourself or your werewolf partner, indicate it."
,...
,"Player7": "the true role (werewolf / special role / villager, living or dead) you infer about him/her, and why is this role? If the player is yourself or your werewolf partner, indicate it."
,"GAME_STATE_SUMMARIZATION": "summarize the current situation from your standpoint in one sentence, your summarization should catch the most important information from your reflection, such as conflicts, number of living werewolves, special roles, and villagers."
}""", "json_object", required=True)
    response_format.update_describe_with_dict(kwargs)
    return response_format

@retry(stop=stop_after_attempt(4), wait=wait_fixed(1), retry_error_cls={KeyError, json.JSONDecodeError})
async def run(self, profile: str, name: str, context: str, latest_instruction: str):
prompt = (
self.PROMPT_TEMPLATE.replace("__context__", context)
Expand All @@ -234,7 +320,18 @@ async def run(self, profile: str, name: str, context: str, latest_instruction: s
.replace("__latest_instruction__", latest_instruction)
)

rsp = await self._aask(prompt)
response_formatter = self._construct_response_format(
__profile__=profile, __name__=name, __context__=context, __latest_instruction__=latest_instruction
)
rsp = await self._aask(prompt, response_format=response_formatter)
rsp_json = log_and_parse_json(self.name, rsp)

return json.dumps(rsp_json["REFLECTION"])
try:
rsp_reflection = rsp_json["REFLECTION"]
except KeyError:
rsp_reflection = response_formatter.get_most_likely_key(rsp_json, "REFLECTION")

if isinstance(rsp_reflection, str):
return rsp_reflection
else:
return json.dumps(rsp_reflection)
3 changes: 2 additions & 1 deletion metagpt/provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock_api import BedrockLLM
from metagpt.provider.ark_api import ArkLLM

from metagpt.provider.llmstudio_api import LLMStudioLLM
__all__ = [
"GeminiLLM",
"OpenAILLM",
Expand All @@ -34,4 +34,5 @@
"AnthropicLLM",
"BedrockLLM",
"ArkLLM",
"LLMStudioLLM",
]
10 changes: 6 additions & 4 deletions metagpt/provider/anthropic_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.format import ResponseFormat
from typing import Optional


@register_provider([LLMType.ANTHROPIC, LLMType.CLAUDE])
Expand Down Expand Up @@ -42,15 +44,15 @@ def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool
def get_choice_text(self, resp: Message) -> str:
return resp.content[0].text

async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None) -> Message:
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
self._update_costs(resp.usage, self.model)
return resp

async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None) -> Message:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout), response_format=response_format)

async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None) -> str:
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = Usage(input_tokens=0, output_tokens=0)
Expand Down
5 changes: 3 additions & 2 deletions metagpt/provider/ark_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from metagpt.logs import log_llm_stream
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.format import ResponseFormat
from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS


Expand Down Expand Up @@ -71,7 +72,7 @@ def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_
if self.pricing_plan in self.cost_manager.token_costs:
super()._update_costs(usage, self.pricing_plan, local_calc_usage)

async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
stream=True,
Expand All @@ -92,7 +93,7 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFI
self._update_costs(usage, chunk.model)
return full_reply_content

async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage, rsp.model)
Expand Down
17 changes: 9 additions & 8 deletions metagpt/provider/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs

from metagpt.utils.format import ResponseFormat

class BaseLLM(ABC):
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
Expand Down Expand Up @@ -132,6 +132,7 @@ async def aask(
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=USE_CONFIG_TIMEOUT,
response_format: Optional[ResponseFormat] = None,
stream=None,
) -> str:
if system_msgs:
Expand All @@ -149,7 +150,7 @@ async def aask(
if stream is None:
stream = self.config.stream
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout), response_format=response_format)
return rsp

def _extract_assistant_rsp(self, context):
Expand All @@ -169,11 +170,11 @@ async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE
raise NotImplementedError

@abstractmethod
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None):
"""_achat_completion implemented by inherited class"""

@abstractmethod
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None):
"""Asynchronous version of completion
All GPTAPIs are required to provide the standard OpenAI completion interface
[
Expand All @@ -184,7 +185,7 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""

@abstractmethod
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None) -> str:
"""_achat_completion_stream implemented by inherited class"""

@retry(
Expand All @@ -195,12 +196,12 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = US
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT
self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT, response_format: Optional[ResponseFormat] = None
) -> str:
"""Asynchronous version of completion. Return str. Support stream-print"""
if stream:
return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout))
resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout), response_format=response_format)
resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout), response_format=response_format)
return self.get_choice_text(resp)

def get_choice_text(self, rsp: dict) -> str:
Expand Down
Loading
Loading