Skip to content

Commit 3ca0d7b

Browse files
committed
Add ToolUseAgent class for managing tool interactions in chat.
1 parent eab3d41 commit 3ca0d7b

File tree

1 file changed

+233
-0
lines changed

1 file changed

+233
-0
lines changed

src/Agent/ToolUseAgent.php

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
<?php
2+
3+
namespace Hyperf\Odin\Agent;
4+
5+
use Hyperf\Odin\Api\OpenAI\Request\ToolDefinition;
6+
use Hyperf\Odin\Api\OpenAI\Response\ChatCompletionResponse;
7+
use Hyperf\Odin\Message\AssistantMessage;
8+
use Hyperf\Odin\Message\MessageInterface;
9+
use Hyperf\Odin\Message\ToolMessage;
10+
use Hyperf\Odin\Message\UserMessage;
11+
use Hyperf\Odin\Model\ModelInterface;
12+
use Hyperf\Odin\ModelMapper;
13+
use InvalidArgumentException;
14+
use Psr\Log\LoggerInterface;
15+
use function json_encode;
16+
use const JSON_UNESCAPED_UNICODE;
17+
18+
class ToolUseAgent
19+
{
20+
protected ModelInterface $model;
21+
22+
protected array $messages = [];
23+
24+
protected ?LoggerInterface $logger;
25+
26+
protected int $maxMessagesLimit = 50; // 最大消息数量限制
27+
28+
protected array $tools = [];
29+
30+
public function __construct(ModelMapper $modelMapper, ?LoggerInterface $logger = null)
31+
{
32+
$this->model = $modelMapper->getDefaultModel();
33+
$logger && $this->logger = $logger;
34+
}
35+
36+
public function setTools(array $tools): static
37+
{
38+
$this->validateTools($tools);
39+
$this->tools = $tools;
40+
return $this;
41+
}
42+
43+
public function chat(array|string|MessageInterface $messages, float $temperature = 0.2): ChatCompletionResponse
44+
{
45+
$this->handleMessages($messages);
46+
47+
// Trim messages to avoid overflow
48+
$this->trimMessages();
49+
50+
try {
51+
chat_call:
52+
$response = $this->model->chat(messages: $this->messages, temperature: $temperature, tools: $this->tools);
53+
$this->logger?->debug('FinishedReason: ' . $response->getFirstChoice()->getFinishReason());
54+
// 判断 finish_reason 是否为 length
55+
length_check:
56+
if ($response->getFirstChoice()->getFinishReason() === 'length') {
57+
$messageContent = $response->getFirstChoice()->getMessage()->getContent();
58+
59+
// 重新调用 llm 续写剩余内容
60+
$newResponse = $this->model->chat(messages: array_merge($this->messages, [new AssistantMessage($messageContent), new UserMessage('Continue')]), temperature: $temperature, tools: $this->tools);
61+
$newMessageContent = $newResponse->getFirstChoice()->getMessage()->getContent();
62+
63+
// 拼接续写后的内容
64+
$finalContent = $messageContent . $newMessageContent;
65+
$newResponse->getFirstChoice()->getMessage()->setContent($finalContent);
66+
$response = $newResponse;
67+
goto length_check;
68+
}
69+
70+
if ($response->getFirstChoice()->getMessage() instanceof AssistantMessage) {
71+
$this->messages[] = $response->getFirstChoice()->getMessage();
72+
}
73+
74+
// Log the response for each step
75+
$message = $response->getFirstChoice()->getMessage();
76+
if ($message->getContent()) {
77+
$this->logger?->debug('AI Message: ' . $response);
78+
}
79+
80+
// Check if the response indicates a tool call
81+
if ($response->getFirstChoice()->isFinishedByToolCall()) {
82+
// Call the appropriate tool
83+
$results = $this->callTool($response, $this->tools);
84+
foreach ($results as $callId => $result) {
85+
// Build Tool Message
86+
$this->messages[] = new ToolMessage($result, $callId);
87+
}
88+
// 检查是否所有的 ToolCall 都有对应 CallID 的 ToolMessage
89+
$toolCalls = $response->getFirstChoice()->getMessage()->getToolCalls();
90+
$toolCallIds = array_map(fn($toolCall) => $toolCall->getId(), $toolCalls);
91+
$toolMessageIds = [];
92+
foreach ($this->messages as $message) {
93+
if ($message instanceof ToolMessage) {
94+
$toolMessageIds[] = $message->getToolCallId();
95+
}
96+
}
97+
$missingToolCallIds = array_diff($toolCallIds, $toolMessageIds);
98+
if (!empty($missingToolCallIds)) {
99+
// 构造空的 ToolMessage 加到 $this->messages 中
100+
foreach ($missingToolCallIds as $missingToolCallId) {
101+
$this->messages[] = new ToolMessage('No Result.', $missingToolCallId);
102+
}
103+
}
104+
// Trim messages after tool call
105+
$this->trimMessages();
106+
goto chat_call;
107+
}
108+
} catch (\Exception $e) {
109+
$errorMessage = is_array($e->getMessage()) ? json_encode($e->getMessage()) : $e->getMessage();
110+
$this->logger?->error('Error during chat: ' . $errorMessage);
111+
throw new \RuntimeException('Error during chat: ' . $errorMessage, previous: $e);
112+
}
113+
return $response;
114+
}
115+
116+
protected function callTool(ChatCompletionResponse $response, array $tools): array
117+
{
118+
$message = $response->getFirstChoice()->getMessage();
119+
if (! $message instanceof AssistantMessage) {
120+
return [];
121+
}
122+
$result = [];
123+
$toolCalls = $message->getToolCalls();
124+
foreach ($toolCalls as $toolCall) {
125+
// Find the tool that matches the tool call
126+
foreach ($tools as $tool) {
127+
if ($tool instanceof ToolDefinition) {
128+
if ($tool->getName() === $toolCall->getName()) {
129+
// Execute the tool
130+
$callToolResult = call_user_func($tool->getToolHandler(), $toolCall->getArguments());
131+
$result[$toolCall->getId()] = $callToolResult;
132+
133+
// Log the tool call result
134+
$this->logger?->debug(sprintf('Tool %s calling with arguments: %s', $tool->getName(), json_encode($toolCall->getArguments(), JSON_UNESCAPED_UNICODE)));
135+
}
136+
}
137+
}
138+
}
139+
return $result;
140+
}
141+
142+
public function handleMessages(array|string|MessageInterface $messages): array
143+
{
144+
if (is_string($messages)) {
145+
$messages = new UserMessage($messages);
146+
$this->messages[] = $messages;
147+
} elseif ($messages instanceof MessageInterface) {
148+
$this->messages[] = $messages;
149+
} elseif (is_array($messages)) {
150+
foreach ($messages as $message) {
151+
if (! $message instanceof MessageInterface) {
152+
throw new InvalidArgumentException('The message must be an instance of MessageInterface.');
153+
}
154+
}
155+
$this->messages = array_merge($this->messages, $messages);
156+
}
157+
return $this->messages;
158+
}
159+
160+
protected function validateTools(array $tools): void
161+
{
162+
foreach ($tools as $tool) {
163+
if (! $tool instanceof ToolDefinition) {
164+
throw new InvalidArgumentException('The tool must be an instance of ToolDefinition.');
165+
}
166+
}
167+
}
168+
169+
public function trimMessages(): void
170+
{
171+
if (count($this->messages) > $this->maxMessagesLimit) {
172+
$firstUserMessage = null;
173+
foreach ($this->messages as $index => $message) {
174+
if ($message instanceof UserMessage) {
175+
$firstUserMessage = $message;
176+
break;
177+
}
178+
}
179+
180+
$deleteMessages = [];
181+
foreach ($this->messages as $index => $message) {
182+
if ($message instanceof AssistantMessage) {
183+
$toolCalls = $message->getToolCalls();
184+
if (!empty($toolCalls)) {
185+
for ($i = $index + 1; $i < count($this->messages); $i++) {
186+
$nextMessage = $this->messages[$i];
187+
if ($nextMessage instanceof ToolMessage) {
188+
foreach ($toolCalls as $toolCall) {
189+
if ($nextMessage->getToolCallId() === $toolCall->getId()) {
190+
$deleteMessages[] = $nextMessage;
191+
break;
192+
}
193+
}
194+
}
195+
}
196+
$deleteMessages[] = $message;
197+
}
198+
}
199+
}
200+
201+
foreach ($deleteMessages as $deleteMessage) {
202+
$key = array_search($deleteMessage, $this->messages);
203+
if ($key !== false) {
204+
unset($this->messages[$key]);
205+
}
206+
}
207+
208+
$this->messages = array_values($this->messages);
209+
210+
$this->messages = array_slice($this->messages, -$this->getMaxMessagesLimit() + 1);
211+
212+
if ($firstUserMessage) {
213+
array_unshift($this->messages, $firstUserMessage);
214+
}
215+
}
216+
}
217+
218+
public function getMessages(): array
219+
{
220+
return $this->messages;
221+
}
222+
223+
public function getMaxMessagesLimit(): int
224+
{
225+
return $this->maxMessagesLimit;
226+
}
227+
228+
public function setMaxMessagesLimit(int $maxMessagesLimit): static
229+
{
230+
$this->maxMessagesLimit = $maxMessagesLimit;
231+
return $this;
232+
}
233+
}

0 commit comments

Comments
 (0)