55from functools import partial
66from typing import Callable , Literal
77
8- from litellm import completion , token_counter
8+ import numpy as np
9+ from litellm import completion
910from litellm .types .utils import Message , ModelResponse
11+ from litellm .utils import token_counter
1012from PIL import Image
1113from termcolor import colored
1214
@@ -67,6 +69,10 @@ class AgentConfig:
6769Provide a concise summary that preserves all information needed to continue the task."""
6870
6971
72+ def user_message (content : str | list [dict ]) -> dict :
73+ return {"role" : "user" , "content" : content }
74+
75+
7076class ReactToolCallAgent :
7177 def __init__ (
7278 self ,
@@ -90,41 +96,34 @@ def obs_to_messages(self, obs: dict) -> list[dict]:
9096 """
9197 Convert the observation dictionary into a list of chat messages for Lite LLM
9298 """
99+ images = {k : v for k , v in obs .items () if isinstance (v , (Image .Image , np .ndarray ))}
100+ texts = {k : v for k , v in obs .items () if k not in images and v is not None and v != "" }
93101 messages = []
94- if obs .get ("goal_object" ) and not self .last_tool_call_id :
102+
103+ if not self .last_tool_call_id and (goal_obj := texts .pop ("goal_object" , None )):
95104 # its a first observation when there are no tool_call_id, so include goal
96- goal = obs ["goal_object" ][0 ]["text" ]
97- messages .append ({"role" : "user" , "content" : f"## Goal:\n { goal } " })
98- text_obs = []
99- if result := obs .get ("action_result" ):
100- text_obs .append (f"## Action Result:\n { result } " )
101- if error := obs .get ("last_action_error" ):
102- text_obs .append (f"## Action Error:\n { error } " )
103- if self .config .use_html and (html := obs .get ("pruned_html" )):
104- text_obs .append (f"## HTML:\n { html } " )
105- if self .config .use_axtree and (axtree := obs .get ("axtree_txt" )):
106- text_obs .append (f"## Accessibility Tree:\n { axtree } " )
107- content = "\n \n " .join (text_obs )
108- if content :
109- if self .last_tool_call_id :
110- message = {
111- "role" : "tool" ,
112- "tool_call_id" : self .last_tool_call_id ,
113- "content" : content ,
114- }
115- else :
116- message = {"role" : "user" , "content" : content }
117- messages .append (message )
118- if self .config .use_screenshot and (scr := obs .get ("screenshot" )):
119- if isinstance (scr , Image .Image ):
105+ goal = goal_obj [0 ]["text" ]
106+ messages .append (user_message (f"Goal: { goal } " ))
107+
108+ text = "\n \n " .join ([f"## { k } \n { v } " for k , v in texts .items ()])
109+ if self .last_tool_call_id :
110+ message = {
111+ "role" : "tool" ,
112+ "tool_call_id" : self .last_tool_call_id ,
113+ "content" : text ,
114+ }
115+ else :
116+ message = user_message (text )
117+ messages .append (message )
118+
119+ if self .config .use_screenshot :
120+ for caption , image in images .items ():
120121 image_content = [
121- {"type" : "image_url" , "image_url" : {"url" : image_to_png_base64_url (scr )}}
122+ {"type" : "text" , "text" : caption },
123+ {"type" : "image_url" , "image_url" : {"url" : image_to_png_base64_url (image )}},
122124 ]
123- messages .append ({"role" : "user" , "content" : image_content })
124- else :
125- raise ValueError (
126- f"Expected Image.Image in screenshot obs, got { type (obs ['screenshot' ])} "
127- )
125+ messages .append (user_message (image_content ))
126+
128127 return messages
129128
130129 def get_action (self , obs : dict ) -> tuple [ToolCall , dict ]:
0 commit comments