forked from GT-STAR-Lab/CIMER
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhyperparameter_sweep.py
More file actions
67 lines (52 loc) · 2.12 KB
/
Copy pathhyperparameter_sweep.py
File metadata and controls
67 lines (52 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import itertools
import json
import os
import shlex
# A very simple grid-sweep driver, written generically so that it works with
# every script and parameter setup in this repo.


# Based on https://stackoverflow.com/a/40623158
def dict_product(dicts):
    """Yield the Cartesian product of the value collections in *dicts*.

    Each value of *dicts* holds the candidate values for its key; every
    yielded dict picks one candidate per key. Keys are combined in insertion
    order, with the last key varying fastest. A value that is not iterable
    (a number, boolean or ``None``) is treated as a single candidate, so a
    fixed setting does not have to be wrapped in a list. Strings are
    iterable and are therefore expanded character by character.

    >>> for point in dict_product(dict(number=[1, 2], character='ab')):
    ...     print(point)
    {'number': 1, 'character': 'a'}
    {'number': 1, 'character': 'b'}
    {'number': 2, 'character': 'a'}
    {'number': 2, 'character': 'b'}
    >>> list(dict_product({'seed': 0, 'lr': [0.1, 0.01]}))
    [{'seed': 0, 'lr': 0.1}, {'seed': 0, 'lr': 0.01}]
    """
    choices = []
    for value in dicts.values():
        try:
            iter(value)
        except TypeError:
            # A lone scalar used to make itertools.product raise TypeError;
            # treat it as a one-element axis instead.
            value = [value]
        choices.append(value)
    return (dict(zip(dicts, combo)) for combo in itertools.product(*choices))
def delimit_task_id(task_id):
    """Return the part of *task_id* that precedes its first ``'-'``.

    If *task_id* contains no hyphen, the whole string is returned.
    """
    head, _separator, _tail = task_id.partition('-')
    return head
def _build_command(script_name, param_dict, run_name, report_path):
    """Return the shell command that runs *script_name* for one grid point.

    A list value becomes a multi-valued option (``--key a b c``); any other
    value becomes ``--key value``. Option names, values, the run name and the
    report path are quoted with ``shlex.quote``, so a value containing spaces
    or shell metacharacters reaches the script as exactly one argument.
    Simple values (numbers, plain identifiers, paths) are emitted unchanged.

    *script_name* is inserted verbatim, exactly as if it were typed after
    ``python`` on the command line. stdout is redirected to *report_path*.
    """
    tokens = ['python', script_name]
    for key, value in param_dict.items():
        tokens.append(shlex.quote(f'--{key}'))
        values = value if isinstance(value, list) else [value]
        tokens.extend(shlex.quote(str(elem)) for elem in values)
    tokens += ['--run_name', shlex.quote(run_name)]
    return ' '.join(tokens) + ' > ' + shlex.quote(report_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--script_name', type=str, required=True,
                        help='script path as you would run it, e.g. ARS/ARS/ars.py')
    parser.add_argument('--params_path', type=str, required=True,
                        help='JSON file mapping each option name to its candidate values')
    parser.add_argument('--run_base_name', type=str, default=None)
    args = parser.parse_args()

    with open(args.params_path, 'r', encoding='utf-8') as params_file:
        params = dict(json.load(params_file))
    print(params)

    # Without an explicit base name, runs are named after each grid point's
    # task_id, so that key is only mandatory in that case.
    if not args.run_base_name and 'task_id' not in params:
        parser.error("the params file must define 'task_id' unless --run_base_name is given")

    hyp_sweep_params = list(dict_product(params))
    print(hyp_sweep_params)

    # Simple grid search: one sequential run per combination.
    for ctr, param_dict in enumerate(hyp_sweep_params):
        # Give each run a unique name and write its output to a unique report file.
        run_base_name = args.run_base_name
        if not run_base_name:
            run_base_name = delimit_task_id(param_dict['task_id']) + '-hsweeprun'
        run_name = f'{run_base_name}-{ctr}'
        report_path = f'reports/{run_name}.out'
        # If the target directory is missing, the shell redirect fails and the
        # run never starts.
        os.makedirs(os.path.dirname(report_path), exist_ok=True)

        exec_string = _build_command(args.script_name, param_dict, run_name, report_path)
        print(exec_string)
        status = os.system(exec_string)
        if status != 0:
            print(f'warning: {run_name} exited with status {status}')