use xml syntax in the prompt for chat history
elisalimli committed May 1, 2024
1 parent 64a14c1 commit 19f1e9a
Showing 2 changed files with 118 additions and 105 deletions.
libs/superagent/app/agents/llm.py: 221 changes (117 additions, 104 deletions)
@@ -98,7 +98,7 @@ async def call_tool(
             break
 
     if not tool_to_call:
-        msg = f"Function {function.name} not found in tools, avaliable tool names: {', '.join([tool.name for tool in tools])}"
+        msg = f"Tool {function.name} not found in tools, avaliable tool names: {', '.join([tool.name for tool in tools])}"
         logger.error(msg)
         return ToolCallResponse(
             action_log=action_log,
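This hunk only rewords the error string: the model asked for a "function", but everything else in the codebase calls them tools. For context, a minimal sketch of the lookup that produces this message, using an illustrative `SimpleTool` stand-in rather than the repo's real tool type:

```python
from dataclasses import dataclass


@dataclass
class SimpleTool:
    name: str


def find_tool(name: str, tools: list[SimpleTool]) -> SimpleTool | None:
    # Linear scan over registered tools, mirroring the loop in call_tool
    for tool in tools:
        if tool.name == name:
            return tool
    return None


tools = [SimpleTool("browser"), SimpleTool("calculator")]
if find_tool("search", tools) is None:
    # Same shape as the error message above
    print(f"Tool search not found in tools, available tool names: "
          f"{', '.join(t.name for t in tools)}")
```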
Expand Down Expand Up @@ -193,11 +193,11 @@ def prompt(self):
inital_token_usage=len(prompt),
)
if len(messages) > 0:
prompt += "\n\n Previous messages: \n"
prompt += "\n\n Here's the previous conversation: <chat_history> \n"
for message in messages:
prompt += (
f"""{message.type.value.capitalize()}: {message.content}\n\n"""
)
prompt += f"""<{message.type.value}> {message.content} </{message.type.value}>\n"""
prompt += " </chat_history> \n"

return prompt

@property
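This hunk is the heart of the commit: the previous conversation is now wrapped in a `<chat_history>` block, with each turn enclosed in a tag named after its message type, instead of loose `Human:` / `Ai:` prefixes. XML-style delimiters are a common prompting recommendation (Anthropic's docs in particular suggest them) because the model can tell unambiguously where history ends and the new request begins. A rough sketch of what the new code emits, with an illustrative `Msg` stand-in for the repo's message objects:

```python
from dataclasses import dataclass


@dataclass
class Msg:
    type: str  # stand-in for message.type.value, e.g. "human" or "ai"
    content: str


messages = [Msg("human", "What's the weather?"), Msg("ai", "It's sunny.")]

prompt = "You are a helpful assistant."
if len(messages) > 0:
    prompt += "\n\n Here's the previous conversation: <chat_history> \n"
    for message in messages:
        prompt += f"<{message.type}> {message.content} </{message.type}>\n"
    prompt += " </chat_history> \n"

print(prompt)
# You are a helpful assistant.
#
#  Here's the previous conversation: <chat_history>
# <human> What's the weather? </human>
# <ai> It's sunny. </ai>
#  </chat_history>
```

Note that the closing `prompt += " </chat_history> \n"` sits outside the loop, so the block is terminated exactly once however many messages there are.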
Expand Down Expand Up @@ -267,7 +267,7 @@ async def _execute_tools(
)
new_message = {
"role": "tool",
"name": tool_call.get("function").get("name"),
"name": tool_call.function.name,
"content": tool_call_res.result,
}
if tool_call.id:
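The old code treated a tool call as a nested dict; the new code uses attribute access, which matches the typed tool-call objects that litellm's OpenAI-style responses return. A small sketch of the difference, with `SimpleNamespace` standing in for the real response type (illustrative only):

```python
from types import SimpleNamespace

# Object shaped like an OpenAI/litellm tool call (illustrative stand-in)
tool_call = SimpleNamespace(
    id="call_123",
    function=SimpleNamespace(name="browser", arguments='{"url": "..."}'),
)

# Old style assumed a plain dict; on a typed object it fails:
# tool_call.get("function")  # AttributeError: no attribute 'get'

# New style reads the attributes directly
assert tool_call.function.name == "browser"
```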
Expand Down Expand Up @@ -402,7 +402,7 @@ async def _acompletion(self, depth: int = 0, **kwargs) -> Any:
self.memory.aadd_message(
message=BaseMessage(
type=MessageType.TOOL_CALL,
content=json.dumps(tool_call),
content=tool_call.json(),
)
)
for tool_call in tool_calls
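`json.dumps` only understands plain Python values, so passing it a model object raises `TypeError`; letting the object serialize itself sidesteps that. A minimal sketch assuming the tool call is a pydantic-v1-style model (`.json()` is the v1 spelling; pydantic v2 renames it `model_dump_json()`), with an illustrative `ToolCall` stand-in:

```python
import json

from pydantic import BaseModel


class ToolCall(BaseModel):
    # Illustrative stand-in for the real tool-call type
    id: str
    name: str


call = ToolCall(id="call_123", name="browser")

# json.dumps(call)  # TypeError: Object of type ToolCall is not JSON serializable
print(call.json())  # {"id": "call_123", "name": "browser"}
```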
Expand Down Expand Up @@ -512,115 +512,128 @@ def messages(self):
]

async def ainvoke(self, input, *_, **kwargs):
self.input = input
tool_results = []
output = ""
try:
self.input = input
tool_results = []

await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.HUMAN,
content=self.input,
)
)
if self.enable_streaming:
self._set_streaming_callback(
kwargs.get("config", {}).get("callbacks", [])
)

if self.enable_streaming:
self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", []))
if len(self.tools) > 0:
openai_llm = await prisma.llm.find_first(
where={
"provider": LLMProvider.OPENAI.value,
"apiUserId": self.agent_data.apiUserId,
}
)
if openai_llm:
openai_api_key = openai_llm.apiKey
else:
openai_api_key = config("OPENAI_API_KEY")
logger.warn(
"OpenAI API Key not found in database, using environment variable"
)

if len(self.tools) > 0:
openai_llm = await prisma.llm.find_first(
where={
"provider": LLMProvider.OPENAI.value,
"apiUserId": self.agent_data.apiUserId,
}
)
if openai_llm:
openai_api_key = openai_llm.apiKey
else:
openai_api_key = config("OPENAI_API_KEY")
logger.warn(
"OpenAI API Key not found in database, using environment variable"
res = await acompletion(
api_key=openai_api_key,
model="gpt-3.5-turbo-0125",
messages=self.messages_function_calling,
tools=self.tools,
stream=False,
)

res = await acompletion(
api_key=openai_api_key,
model="gpt-3.5-turbo-0125",
messages=self.messages_function_calling,
tools=self.tools,
stream=False,
)
tool_calls = []
if (
hasattr(res.choices[0].message, "tool_calls")
and res.choices[0].message.tool_calls
):
tool_calls = res.choices[0].message.tool_calls

for tool_call in tool_calls:
tool_call_res = await call_tool(
agent_data=self.agent_data,
session_id=self.session_id,
function=tool_call.function,
)

tool_calls = res.choices[0].message.get("tool_calls", [])
for tool_call in tool_calls:
tool_call_res = await call_tool(
agent_data=self.agent_data,
session_id=self.session_id,
function=tool_call.function,
)
# TODO: handle the failure in tool call case
# if not intermediate_step.success:
# self.memory.add_message(
# message=BaseMessage(
# type=MessageType.TOOL_RESULT,
# content=intermediate_step.result,
# )
# )

tool_results.append(
(tool_call_res.action_log, tool_call_res.result)
)

# TODO: handle the failure in tool call case
# if not intermediate_step.success:
# self.memory.add_message(
# message=BaseMessage(
# type=MessageType.TOOL_RESULT,
# content=intermediate_step.result,
# )
# )

tool_results.append((tool_call_res.action_log, tool_call_res.result))

if tool_call_res.return_direct:
if self.enable_streaming:
await self._stream_text_by_lines(tool_call_res.result)
self.streaming_callback.done.set()

return {
"intermediate_steps": tool_results,
"input": self.input,
"output": tool_call_res.result,
}
if tool_call_res.return_direct:
if self.enable_streaming:
await self._stream_text_by_lines(tool_call_res.result)
self.streaming_callback.done.set()

output = tool_call_res.result

return {
"intermediate_steps": tool_results,
"input": self.input,
"output": output,
}

if len(tool_results) > 0:
INPUT_TEMPLATE = "{input}\n Context: {context}\n"
self.input = INPUT_TEMPLATE.format(
input=self.input,
context="\n\n".join(
[tool_response for (_, tool_response) in tool_results]
),
)

if len(tool_results) > 0:
INPUT_TEMPLATE = "{input}\n Context: {context}\n"
self.input = INPUT_TEMPLATE.format(
input=self.input,
context="\n\n".join(
[tool_response for (_, tool_response) in tool_results]
),
params = self.llm_data.params.dict(exclude_unset=True)
second_res = await acompletion(
api_key=self.llm_data.llm.apiKey,
model=self.llm_data.model,
messages=self.messages,
stream=self.enable_streaming,
**params,
)

params = self.llm_data.params.dict(exclude_unset=True)
second_res = await acompletion(
api_key=self.llm_data.llm.apiKey,
model=self.llm_data.model,
messages=self.messages,
stream=self.enable_streaming,
**params,
)

output = ""
if self.enable_streaming:
await self.streaming_callback.on_llm_start()
second_res = cast(CustomStreamWrapper, second_res)
if self.enable_streaming:
await self.streaming_callback.on_llm_start()
second_res = cast(CustomStreamWrapper, second_res)

async for chunk in second_res:
token = chunk.choices[0].delta.content
if token:
output += token
await self.streaming_callback.on_llm_new_token(token)
async for chunk in second_res:
token = chunk.choices[0].delta.content
if token:
output += token
await self.streaming_callback.on_llm_new_token(token)

self.streaming_callback.done.set()
else:
second_res = cast(ModelResponse, second_res)
output = second_res.choices[0].message.content
self.streaming_callback.done.set()
else:
second_res = cast(ModelResponse, second_res)
output = second_res.choices[0].message.content

await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.AI,
content=output,
return {
"intermediate_steps": tool_results,
"input": self.input,
"output": output,
}
finally:
await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.HUMAN,
content=self.input,
)
)
)

return {
"intermediate_steps": tool_results,
"input": self.input,
"output": output,
}
await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.AI,
content=output,
)
)
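The `ainvoke` rewrite is mostly an indentation shift into one `try` block, along with a more defensive `tool_calls` extraction (`hasattr` on the typed message instead of dict `.get`), but the payoff is in the `finally`: the human turn and the AI turn are now persisted on every exit path, including the early `return` taken when a tool has `return_direct` set, a path that previously skipped the AI-message write entirely. One side effect visible in the diff: `self.input` is rewritten with tool context via `INPUT_TEMPLATE` before `finally` runs, so the persisted human message now includes that context. The control-flow pattern, reduced to a runnable sketch with illustrative stand-ins for the memory backend and the LLM call:

```python
import asyncio


class MemoryStub:
    """Illustrative stand-in for the agent's memory backend."""

    async def aadd_message(self, type: str, content: str) -> None:
        print(f"persisted {type}: {content!r}")


class AgentSketch:
    def __init__(self) -> None:
        self.memory = MemoryStub()
        self.input = ""

    async def ainvoke(self, input: str) -> dict:
        self.input = input
        output = ""  # initialized up front so `finally` always has a value
        try:
            # Stand-in for the tool-calling and completion logic; any early
            # return or exception in here still reaches the finally block.
            output = f"echo: {input}"
            return {"intermediate_steps": [], "input": self.input, "output": output}
        finally:
            # Runs on normal return, early return_direct return, and errors alike
            await self.memory.aadd_message("HUMAN", self.input)
            await self.memory.aadd_message("AI", output)


asyncio.run(AgentSketch().ainvoke("hi"))
```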
libs/superagent/poetry.lock: 2 changes (1 addition, 1 deletion)

Some generated files are not rendered by default.
