-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsingle_run.py
More file actions
142 lines (123 loc) · 4.7 KB
/
single_run.py
File metadata and controls
142 lines (123 loc) · 4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from controllers.env import MacOSEnv
from pathlib import Path
from utils.logger import ProjectLogger
import time
from mm_agents.agent import PromptAgent
from mm_agents.aguvis_agent import AguvisAgent
from mm_agents.uitars_agent import UITARSAgent
from mm_agents.internvl_agent import InternvlAgent
from mm_agents.simple_qwenvl_agent import SimpleQwenvlAgent
import datetime
import os
import json
import argparse
# Directory that contains this script. `.parent` is required: resolve() alone
# yields the script *file*, which would put the log directory underneath a
# regular file (".../single_run.py/logs").
script_dir = Path(__file__).resolve().parent
# Logger setup: logs are written to <script directory>/logs.
logger = ProjectLogger(log_dir=script_dir / "logs")
def wait_for_ssh(
    env: MacOSEnv,
    max_wait: int = 300,
    interval: int = 5,
    boot_wait: float = 15,
):
    """Retry ``env.connect_ssh()`` until it yields an active SSH transport.

    Args:
        env: Environment whose ``connect_ssh()`` populates ``env.ssh_client``.
        max_wait: Wall-clock budget in seconds for the whole retry loop,
            including time spent blocked inside ``connect_ssh()`` itself.
        interval: Seconds to sleep between failed attempts (never sleeps
            past the deadline).
        boot_wait: Seconds to sleep after a successful connection before
            returning, to let the guest finish booting.

    Raises:
        TimeoutError: If no active transport was obtained within
            ``max_wait`` seconds. The last connection error is chained as
            ``__cause__``.
    """
    # A monotonic deadline bounds the real elapsed time; counting only the
    # sleeps would ignore however long each connect_ssh() call blocks.
    deadline = time.monotonic() + max_wait
    attempt = 1
    last_error = None
    while time.monotonic() < deadline:
        try:
            logger.info(f"[SSH Attempt {attempt}] Trying to connect...")
            env.connect_ssh()
            # connect_ssh() returning is not proof of a usable session, so
            # check the underlying transport explicitly.
            transport = env.ssh_client.get_transport() if env.ssh_client else None
            if not transport or not transport.is_active():
                raise ConnectionError("SSH transport not active after connect()")
        except Exception as e:  # retry boundary: any failure means "try again"
            last_error = e
            logger.warning(f"[SSH Attempt {attempt}] Failed: {type(e).__name__}: {e}")
            remaining = deadline - time.monotonic()
            if remaining > 0:
                time.sleep(min(interval, remaining))
            attempt += 1
        else:
            logger.info("✅ SSH connected successfully.")
            # wait for boot
            time.sleep(boot_wait)
            return
    raise TimeoutError(
        f"❌ SSH connection failed after waiting {max_wait} seconds."
    ) from last_error
def _save_step_artifacts(result_dir, step_num, action_timestamp, action, obs, reward, done, info):
    """Write one step's screenshot and append its record to ``traj.jsonl``."""
    screenshot_file = f"step_{step_num}_{action_timestamp}.png"
    with open(os.path.join(result_dir, screenshot_file), "wb") as img:
        img.write(obs["screenshot"])
    record = {
        "step_num": step_num,
        "action_timestamp": action_timestamp,
        "action": action,
        "reward": reward,
        "done": done,
        "info": info,
        "screenshot_file": screenshot_file,
    }
    # Append mode: repeated runs into the same directory extend the file.
    with open(os.path.join(result_dir, "traj.jsonl"), "a", encoding="utf-8") as traj:
        traj.write(json.dumps(record))
        traj.write("\n")


def do_single_task(
    env: MacOSEnv,
    agent: PromptAgent,
    max_steps: int = 10,
    example_result_dir: str = "results/example_run_1",
):
    """Run one agent episode on ``env.task`` and persist its trajectory.

    The agent is queried up to ``max_steps`` times with the task instruction
    and the latest observation. Every action it returns is executed through
    ``env.step``. For each action, a screenshot and a JSON line are written
    to ``example_result_dir``. The episode ends early once ``env.step``
    reports ``done``. Afterwards ``env.evaluate_task()`` is written to
    ``result.txt``. The screen recording is always saved to
    ``recording.mp4``, even if the episode raises.

    Args:
        env: Environment with an initialized task.
        agent: Agent whose ``predict(instruction, obs)`` returns
            ``(response, actions)``.
        max_steps: Maximum number of ``agent.predict`` calls.
        example_result_dir: Output directory; created if it does not exist.

    Returns:
        The value returned by ``env.evaluate_task()``.
    """
    # The output directory must exist before any per-step file is opened.
    os.makedirs(example_result_dir, exist_ok=True)
    obs = env._get_obs()  # Initial observation, before any action.
    done = False
    step_idx = 0
    env.start_recording()
    try:
        while not done and step_idx < max_steps:
            _response, actions = agent.predict(env.task.instruction, obs)
            for action in actions:
                # Timestamp taken before executing, so it marks when the
                # action was issued.
                action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
                logger.logger.info("Step %d: %s", step_idx + 1, action)
                obs, reward, done, info = env.step(action)
                logger.logger.info("Reward: %.2f", reward)
                logger.logger.info("Done: %s", done)
                _save_step_artifacts(
                    example_result_dir, step_idx + 1, action_timestamp,
                    action, obs, reward, done, info,
                )
                if done:
                    logger.logger.info("The episode is done.")
                    break
            step_idx += 1
        result = env.evaluate_task()
        logger.logger.info("Result: %.2f", result)
        with open(
            os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8"
        ) as f:
            f.write(f"{result}\n")
    finally:
        # Stop the recording even when the agent, env, or evaluation raised,
        # so the recorder is not left running.
        env.end_recording(os.path.join(example_result_dir, "recording.mp4"))
    return result
def main():
    """Run a single debug task: reset the VM, wait for SSH, run the agent.

    Command-line options:
        --config_file: env config YAML passed to ``MacOSEnv``.
        --task_path: task JSON passed to ``MacOSEnv.init_task``.
        --max_steps: maximum number of agent predictions for the episode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        default="config/default_config_linux.yaml",
        help="Path to env config YAML file",
    )
    parser.add_argument(
        "--task_path",
        type=str,
        default="tasks/clock/1.json",
        help="Path to the task JSON file to run",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=10,
        help="Maximum number of agent prediction steps",
    )
    args = parser.parse_args()
    # ======
    # For debug only
    # ======
    # Initialize the environment with the selected config
    macos_env = MacOSEnv(config_file=args.config_file)
    # Select agent (uncomment one alternative to switch models)
    # agent = PromptAgent()
    agent = InternvlAgent(model="ScaleCUA")
    # agent = AguvisAgent(executor_model="uground7b")
    # agent = SimpleQwenvlAgent()
    # agent = AguvisAgent(executor_model="Aguvis-72B-720P")
    # agent = UITARSAgent(model="ui_tars_15_7b")
    agent.reset()
    # Restart the docker if needed
    macos_env._reset_env()
    # wait_for_ssh() returns only after connect_ssh() succeeded with an
    # active transport, so no second connect_ssh() call is made here.
    wait_for_ssh(macos_env)
    try:
        task_path = Path(args.task_path).resolve()
        macos_env.init_task(task_path)
        do_single_task(macos_env, agent, max_steps=args.max_steps)
    finally:
        # Close the SSH connection even if the task failed.
        macos_env.close_connection()
if __name__ == "__main__":
    # Entry point: run only when executed as a script, never on import.
    main()