-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathjob_2_7.py
77 lines (63 loc) · 2.23 KB
/
job_2_7.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from copy import deepcopy
from pathlib import Path
from typing import Dict
from framework.cli.job import InteractiveCombinedJob, InteractiveCombinedJobSpec
# %-style logging record format: timestamp | level | module | message.
# NOTE(review): not referenced in this file's visible code; presumably handed
# to the logging configuration elsewhere — confirm.
logger_format = "%(asctime)s | %(levelname)s | %(module)s | %(message)s"
# strftime pattern used to render %(asctime)s.
logger_dateformat = '%Y-%m-%d %H:%M:%S'
class Job_2_7(InteractiveCombinedJob):
    """Job wrapper for script 2_7 (``2_7_interactive_combined.py``).

    The RAM requested for a run is the sum of a sensorimotor component
    (chosen from the pruning distance by ``SM_RAM``) and a linguistic
    component (looked up in ``LING_RAM`` by model name and word count).
    """

    # max_sphere_radius (i.e. pruning distance) -> RAM/G
    def SM_RAM(self, distance: float) -> int:
        """Return the sensorimotor RAM requirement (GB) for a pruning distance.

        Args:
            distance: the sensorimotor max sphere radius / pruning distance.

        Returns:
            RAM in GB, stepped by distance threshold.
        """
        if distance <= 1:
            return 5
        elif distance <= 1.5:
            return 30
        # 1.98 is the largest min edge length, so the threshold below which
        # the graph becomes disconnected.
        elif distance <= 1.98:
            return 55
        elif distance <= 2:
            return 60
        else:
            # Max
            return 120

    # Linguistic model name -> {n_words: RAM in GB}.
    # Only the (model_name, n_words) combinations listed here are supported.
    LING_RAM: Dict[str, Dict[int, int]] = {
        "pmi_ngram": {
            1_000: 2,
            3_000: 3,
            10_000: 7,
            30_000: 11,
            40_000: 15,
            60_000: 20,
        },
        "ppmi_ngram": {
            1_000: 2,
            3_000: 3,
            10_000: 5,
            30_000: 7,
            40_000: 9,
            60_000: 11,
        }
    }

    def __init__(self, spec: InteractiveCombinedJobSpec):
        """Create the job for script ``2_7_interactive_combined.py`` with *spec*."""
        super().__init__(
            script_number="2_7",
            script_name="2_7_interactive_combined.py",
            spec=spec)

    @property
    def _ram_requirement_g(self) -> int:
        """Total RAM (GB) to request: sensorimotor part plus linguistic part.

        Raises:
            TypeError: if ``self.spec`` is not an ``InteractiveCombinedJobSpec``.
            KeyError: if ``LING_RAM`` has no entry for the spec's linguistic
                model name / ``n_words`` combination.
        """
        spec = self.spec
        # Explicit check rather than `assert`: asserts are stripped under -O,
        # which would turn a wrong spec type into an obscure AttributeError.
        if not isinstance(spec, InteractiveCombinedJobSpec):
            raise TypeError(
                f"Job_2_7 requires an InteractiveCombinedJobSpec, "
                f"got {type(spec).__name__}")
        # Sensorimotor part is evaluated first, as in the original expression.
        sm_ram = self.SM_RAM(spec.sensorimotor_spec.max_radius)
        ling = spec.linguistic_spec
        try:
            ling_ram = self.LING_RAM[ling.model_name][ling.n_words]
        except KeyError as err:
            raise KeyError(
                f"No linguistic RAM estimate for model {ling.model_name!r} "
                f"with n_words={ling.n_words!r}") from err
        return sm_ram + ling_ram
if __name__ == '__main__':
    # Sweep every spec in the YAML file across a range of cross-component
    # attenuation (CCA) values: one job per (spec, CCA) pair.
    ccas = [0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1]

    spec_path = Path(__file__).parent / "job_specifications/2022-05-24 longer runs more ccas.yaml"

    def _with_cca(base_spec: InteractiveCombinedJobSpec, attenuation) -> InteractiveCombinedJobSpec:
        """Return a deep copy of *base_spec* with its CCA set to *attenuation*."""
        variant = deepcopy(base_spec)
        variant.cross_component_attenuation = attenuation
        return variant

    jobs = [
        Job_2_7(_with_cca(base_spec, cca))
        for base_spec in InteractiveCombinedJobSpec.load_multiple(spec_path)
        for cca in ccas
    ]

    # NOTE(review): the flag and its value are passed as a single list element;
    # whether that is split correctly depends on run_locally's implementation — confirm.
    for job in jobs:
        job.run_locally(extra_arguments=["--filter_events accessible_set"])

    print(f"Submitted {len(jobs)} jobs.")