Skip to content

Commit e6f1f5d

Browse files
committed
universal rendering of any dict observation that contains only texts and images
1 parent 4fe4e48 commit e6f1f5d

File tree

1 file changed

+31
-32
lines changed

1 file changed

+31
-32
lines changed

src/agentlab/agents/react_toolcall_agent.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from functools import partial
66
from typing import Callable, Literal
77

8-
from litellm import completion, token_counter
8+
import numpy as np
9+
from litellm import completion
910
from litellm.types.utils import Message, ModelResponse
11+
from litellm.utils import token_counter
1012
from PIL import Image
1113
from termcolor import colored
1214

@@ -67,6 +69,10 @@ class AgentConfig:
6769
Provide a concise summary that preserves all information needed to continue the task."""
6870

6971

72+
def user_message(content: str | list[dict]) -> dict:
73+
return {"role": "user", "content": content}
74+
75+
7076
class ReactToolCallAgent:
7177
def __init__(
7278
self,
@@ -90,41 +96,34 @@ def obs_to_messages(self, obs: dict) -> list[dict]:
9096
"""
9197
Convert the observation dictionary into a list of chat messages for Lite LLM
9298
"""
99+
images = {k: v for k, v in obs.items() if isinstance(v, (Image.Image, np.ndarray))}
100+
texts = {k: v for k, v in obs.items() if k not in images and v is not None and v != ""}
93101
messages = []
94-
if obs.get("goal_object") and not self.last_tool_call_id:
102+
103+
if not self.last_tool_call_id and (goal_obj := texts.pop("goal_object", None)):
95104
# its a first observation when there are no tool_call_id, so include goal
96-
goal = obs["goal_object"][0]["text"]
97-
messages.append({"role": "user", "content": f"## Goal:\n{goal}"})
98-
text_obs = []
99-
if result := obs.get("action_result"):
100-
text_obs.append(f"## Action Result:\n{result}")
101-
if error := obs.get("last_action_error"):
102-
text_obs.append(f"## Action Error:\n{error}")
103-
if self.config.use_html and (html := obs.get("pruned_html")):
104-
text_obs.append(f"## HTML:\n{html}")
105-
if self.config.use_axtree and (axtree := obs.get("axtree_txt")):
106-
text_obs.append(f"## Accessibility Tree:\n{axtree}")
107-
content = "\n\n".join(text_obs)
108-
if content:
109-
if self.last_tool_call_id:
110-
message = {
111-
"role": "tool",
112-
"tool_call_id": self.last_tool_call_id,
113-
"content": content,
114-
}
115-
else:
116-
message = {"role": "user", "content": content}
117-
messages.append(message)
118-
if self.config.use_screenshot and (scr := obs.get("screenshot")):
119-
if isinstance(scr, Image.Image):
105+
goal = goal_obj[0]["text"]
106+
messages.append(user_message(f"Goal: {goal}"))
107+
108+
text = "\n\n".join([f"## {k}\n{v}" for k, v in texts.items()])
109+
if self.last_tool_call_id:
110+
message = {
111+
"role": "tool",
112+
"tool_call_id": self.last_tool_call_id,
113+
"content": text,
114+
}
115+
else:
116+
message = user_message(text)
117+
messages.append(message)
118+
119+
if self.config.use_screenshot:
120+
for caption, image in images.items():
120121
image_content = [
121-
{"type": "image_url", "image_url": {"url": image_to_png_base64_url(scr)}}
122+
{"type": "text", "text": caption},
123+
{"type": "image_url", "image_url": {"url": image_to_png_base64_url(image)}},
122124
]
123-
messages.append({"role": "user", "content": image_content})
124-
else:
125-
raise ValueError(
126-
f"Expected Image.Image in screenshot obs, got {type(obs['screenshot'])}"
127-
)
125+
messages.append(user_message(image_content))
126+
128127
return messages
129128

130129
def get_action(self, obs: dict) -> tuple[ToolCall, dict]:

0 commit comments

Comments
 (0)