forked from Alibaba-NLP/DeepResearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_multi_react.py
More file actions
229 lines (205 loc) · 9.91 KB
/
run_multi_react.py
File metadata and controls
229 lines (205 loc) · 9.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import argparse
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import concurrent.futures
from tqdm import tqdm
import threading
from datetime import datetime
from react_agent import MultiTurnReactAgent
import time
import math
# Ports handed out round-robin to unique questions; the chosen port is attached
# to every task for that question under the "planning_port" key.
PLANNING_PORTS = (6001, 6002, 6003, 6004, 6005, 6006, 6007, 6008)


def parse_args(argv=None):
    """Parse the command-line options of the rollout runner.

    Args:
        argv: Optional argument list; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="")
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--dataset", type=str, default="gaia")
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--presence_penalty", type=float, default=1.1)
    parser.add_argument("--max_workers", type=int, default=20)
    parser.add_argument("--roll_out_count", type=int, default=3)
    parser.add_argument("--total_splits", type=int, default=1)
    parser.add_argument("--worker_split", type=int, default=1)
    return parser.parse_args(argv)


def load_items(data_filepath):
    """Load the evaluation items from a ``.json`` or ``.jsonl`` file.

    Blank lines in a ``.jsonl`` file are skipped so that a trailing empty line
    does not abort the run.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the content is not valid JSON.
        ValueError: If the extension is unsupported or a ``.json`` file does
            not hold a list of objects.
    """
    if data_filepath.endswith(".json"):
        with open(data_filepath, "r", encoding="utf-8") as f:
            items = json.load(f)
        if not isinstance(items, list):
            raise ValueError("Input JSON must be a list of objects.")
        if items and not isinstance(items[0], dict):
            raise ValueError("Input JSON list items must be objects.")
        return items
    if data_filepath.endswith(".jsonl"):
        with open(data_filepath, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]
    raise ValueError("Unsupported file extension. Please use .json or .jsonl files.")


def split_bounds(total_items, total_splits, worker_split):
    """Return the ``(start, end)`` slice bounds of a 1-based worker split.

    Every split covers ``ceil(total_items / total_splits)`` items except
    possibly the last ones, which are truncated at ``total_items``.
    """
    items_per_split = math.ceil(total_items / total_splits)
    start_idx = (worker_split - 1) * items_per_split
    end_idx = min(worker_split * items_per_split, total_items)
    return start_idx, end_idx


def build_output_files(dataset_dir, roll_out_count, total_splits, worker_split):
    """Map each 1-based rollout index to its output ``.jsonl`` path.

    A ``_split{k}of{n}`` suffix is added only when more than one split is used.
    """
    suffix = f"_split{worker_split}of{total_splits}" if total_splits > 1 else ""
    return {
        i: os.path.join(dataset_dir, f"iter{i}{suffix}.jsonl")
        for i in range(1, roll_out_count + 1)
    }


def load_processed_queries(output_file):
    """Collect the stripped questions already answered without error.

    A missing file yields an empty set. Lines that are not valid JSON are
    reported and skipped; records that are not objects, lack a string
    ``"question"`` or carry an ``"error"`` key are ignored so they get re-run.
    """
    processed_queries = set()
    try:
        with open(output_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    print(f"Warning: Skipping invalid line in output file: {line.strip()}")
                    continue
                question = data.get("question") if isinstance(data, dict) else None
                if isinstance(question, str) and "error" not in data:
                    processed_queries.add(question.strip())
    except FileNotFoundError:
        pass
    return processed_queries


def extract_question(item):
    """Return the stripped question of *item*, or ``""`` if none can be found.

    When ``item["question"]`` is missing or empty, the question is recovered
    from ``item["messages"][1]["content"]`` (the text after the first
    ``"User:"`` marker when present) and written back to ``item["question"]``.
    The recovered text is always stripped so that it compares equal to the
    stripped questions used for resume detection.
    """
    raw = item.get("question")
    question = raw.strip() if isinstance(raw, str) else ""
    if question:
        return question
    try:
        user_msg = item["messages"][1]["content"]
        recovered = user_msg.split("User:")[1] if "User:" in user_msg else user_msg
        recovered = recovered.strip()
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        print(f"Extract question from user message failed: {e}")
        return ""
    item["question"] = recovered
    return recovered


def build_tasks(items, processed_queries_per_rollout, planning_ports=PLANNING_PORTS):
    """Build the list of pending tasks for every rollout.

    Args:
        items: Dataset items (dicts) of the current split.
        processed_queries_per_rollout: Maps rollout index to the set of
            questions that rollout has already completed.
        planning_ports: Ports assigned round-robin to unique questions; a
            question keeps the same port across all rollouts.

    Returns:
        ``(tasks, counts)`` where ``tasks`` is a list of dicts with keys
        ``"item"``, ``"rollout_idx"`` and ``"planning_port"``, and ``counts``
        maps each rollout index to its number of pending tasks.
    """
    tasks = []
    counts = {rollout_idx: 0 for rollout_idx in processed_queries_per_rollout}
    question_to_port = {}
    for rollout_idx in sorted(processed_queries_per_rollout):
        processed_queries = processed_queries_per_rollout[rollout_idx]
        for item in items:
            question = extract_question(item)
            if not question:
                print(f"Warning: Skipping item with empty question: {item}")
                continue
            if question in processed_queries:
                continue
            if question not in question_to_port:
                # The number of questions seen so far is the round-robin cursor.
                next_slot = len(question_to_port) % len(planning_ports)
                question_to_port[question] = planning_ports[next_slot]
            tasks.append({
                "item": item.copy(),
                "rollout_idx": rollout_idx,
                "planning_port": question_to_port[question],
            })
            counts[rollout_idx] += 1
    return tasks, counts


def make_error_result(task_info, error):
    """Build the failure record written when a task does not yield a result."""
    item = task_info["item"]
    rollout_idx = task_info["rollout_idx"]
    return {
        "question": item.get("question", ""),
        "answer": item.get("answer", ""),
        "rollout_idx": rollout_idx,
        "rollout_id": rollout_idx,
        "error": error,
        "messages": [],
        "prediction": "[Failed]",
    }


def append_jsonl(path, record):
    """Append *record* to *path* as one JSON line.

    The record is serialized before the file is opened, so a record that
    cannot be serialized leaves the file untouched.
    """
    line = json.dumps(record, ensure_ascii=False) + "\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)


def run_rollouts(tasks, output_files, model, args):
    """Run every task on a thread pool and append each outcome to its file.

    Each task is submitted as ``test_agent._run(task, model)``. Results and
    failure records are written from this (the calling) thread only, as the
    futures complete, so no locking is needed around the output files.
    """
    llm_cfg = {
        'model': model,
        'generate_cfg': {
            'max_input_tokens': 320000,
            'max_retries': 10,
            'temperature': args.temperature,
            'top_p': args.top_p,
            'presence_penalty': args.presence_penalty
        },
        'model_type': 'qwen_dashscope'
    }
    test_agent = MultiTurnReactAgent(
        llm=llm_cfg,
        function_list=["search", "visit", "google_scholar", "PythonInterpreter"]
    )

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_task = {
            executor.submit(test_agent._run, task, model): task
            for task in tasks
        }
        progress = tqdm(
            as_completed(future_to_task),
            total=len(tasks),
            desc="Processing All Rollouts",
        )
        for future in progress:
            task_info = future_to_task[future]
            rollout_idx = task_info["rollout_idx"]
            output_file = output_files[rollout_idx]
            try:
                append_jsonl(output_file, future.result())
            except concurrent.futures.TimeoutError:
                # as_completed() only yields finished futures and result() is
                # called without a timeout, so this branch is reached only when
                # the task itself raised a matching timeout exception.
                question = task_info["item"].get("question", "")
                print(f'Timeout (>1800s): "{question}" (Rollout {rollout_idx})')
                append_jsonl(output_file, make_error_result(task_info, "Timeout (>1800s)"))
            except Exception as exc:
                # Top-level boundary: record the failure so the run continues
                # and the question is retried on the next invocation.
                question = task_info["item"].get("question", "")
                print(f'Task for question "{question}" (Rollout {rollout_idx}) generated an exception: {exc}')
                error_result = make_error_result(task_info, f"Future resolution failed: {exc}")
                print("===============================")
                print(error_result)
                print("===============================")
                append_jsonl(output_file, error_result)


def main():
    """Run all pending rollouts for one dataset split.

    Reads ``eval_data/<dataset>.jsonl``, selects the items of the requested
    split, skips questions already answered without error in the per-rollout
    output files, and runs the remaining ones. Exits with status 1 on an
    invalid split or an unreadable input file.
    """
    args = parse_args()

    model = args.model
    output_base = args.output
    roll_out_count = args.roll_out_count
    total_splits = args.total_splits
    worker_split = args.worker_split

    # Validate worker_split
    if worker_split < 1 or worker_split > total_splits:
        print(f"Error: worker_split ({worker_split}) must be between 1 and total_splits ({total_splits})")
        raise SystemExit(1)

    model_name = os.path.basename(model.rstrip('/'))
    model_dir = os.path.join(output_base, f"{model_name}_sglang")
    dataset_dir = os.path.join(model_dir, args.dataset)
    os.makedirs(dataset_dir, exist_ok=True)

    print(f"Model name: {model_name}")
    print(f"Data set name: {args.dataset}")
    print(f"Output directory: {dataset_dir}")
    print(f"Number of rollouts: {roll_out_count}")
    print(f"Data splitting: {worker_split}/{total_splits}")

    data_filepath = f"eval_data/{args.dataset}.jsonl"
    try:
        items = load_items(data_filepath)
    except FileNotFoundError:
        print(f"Error: Input file not found at {data_filepath}")
        raise SystemExit(1) from None
    except (json.JSONDecodeError, ValueError) as e:
        print(f"Error reading or parsing input file {data_filepath}: {e}")
        raise SystemExit(1) from None

    # Apply data splitting
    total_items = len(items)
    start_idx, end_idx = split_bounds(total_items, total_splits, worker_split)
    items = items[start_idx:end_idx]

    print(f"Total items in dataset: {total_items}")
    print(f"Processing items {start_idx} to {end_idx-1} ({len(items)} items)")

    output_files = build_output_files(dataset_dir, roll_out_count, total_splits, worker_split)
    processed_queries_per_rollout = {
        rollout_idx: load_processed_queries(output_files[rollout_idx])
        for rollout_idx in range(1, roll_out_count + 1)
    }

    tasks_to_run_all, per_rollout_task_counts = build_tasks(items, processed_queries_per_rollout)

    print(f"Total questions in current split: {len(items)}")
    for rollout_idx in range(1, roll_out_count + 1):
        print(f"Rollout {rollout_idx}: already successfully processed: {len(processed_queries_per_rollout[rollout_idx])}, to run: {per_rollout_task_counts[rollout_idx]}")

    if not tasks_to_run_all:
        print("All rollouts have been completed and no execution is required.")
    else:
        run_rollouts(tasks_to_run_all, output_files, model, args)
        print("\nAll tasks completed!")

    print(f"\nAll {roll_out_count} rollouts completed!")


if __name__ == "__main__":
    main()