Skip to content
This repository was archived by the owner on Aug 27, 2024. It is now read-only.

Commit 9f2e132

Browse files
author
Daniel Incicau
committed
Some refactoring + include Almut's stage_order code
1 parent 7ed2267 commit 9f2e132

File tree

7 files changed: +172 -165 lines changed

main.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ def main(benchmark_file):
1919
converter = benchmark.get_converter()
2020
print(benchmark.get_definition())
2121

22-
stages = converter.get_benchmark_stages()
22+
stages = converter.get_stages()
2323
for stage_id in stages:
2424
stage = stages[stage_id]
2525
stage_name = stage["name"]
@@ -31,7 +31,7 @@ def main(benchmark_file):
3131
print(
3232
" Explicit inputs:\n",
3333
[
34-
converter.get_stage_explicit_inputs(i)
34+
converter.get_explicit_inputs(i)
3535
for i in converter.get_stage_implicit_inputs(stage)
3636
],
3737
)
@@ -58,7 +58,7 @@ def main(benchmark_file):
5858
outputs_paths = sorted(benchmark.get_output_paths())
5959
print("All output paths:", outputs_paths)
6060

61-
benchmark.plot_graph()
61+
benchmark.plot_benchmark_graph()
6262

6363
# Serialize workflow to Snakefile
6464
workflow = SnakemakeEngine()

src/converter/converter.py

+96 -111
Original file line number | Diff line number | Diff line change
@@ -1,70 +1,70 @@
11
import os
2+
from pathlib import Path
3+
from typing import Union, Optional, Dict, List
4+
5+
from omni_schema.datamodel import omni_schema
26

37
from src.utils.helpers import merge_dict_list, load_yaml
48

59

610
class LinkMLConverter:
711

8-
def __init__(self, benchmark_file):
9-
self.stage_order_map = None
12+
def __init__(self, benchmark_file: Path):
1013
self.benchmark_file = os.path.abspath(benchmark_file)
11-
self.benchmark = load_yaml(benchmark_file)
14+
self.model = load_yaml(benchmark_file)
15+
16+
def get_name(self) -> str:
17+
"""Get name of the benchmark"""
18+
19+
return self.model.name if self.model.name else self.model.id
20+
21+
def get_definition(self) -> omni_schema.Benchmark:
22+
"""Get underlying benchmark"""
23+
24+
return self.model
1225

13-
def get_benchmark_name(self):
14-
return self.benchmark.name if self.benchmark.name else self.benchmark.id
26+
def get_stages(self) -> Dict[str, omni_schema.Stage]:
27+
"""Get benchmark stages"""
1528

16-
def get_benchmark_definition(self):
17-
return self.benchmark
29+
return dict([(x.id, x) for x in self.model.stages])
1830

19-
def get_stage_id(self, stage):
20-
return stage.id
31+
def get_stage(self, stage_id: str) -> Optional[omni_schema.Stage]:
32+
"""Get stage by stage_id"""
2133

22-
def get_module_id(self, module):
23-
return module.id
34+
return self.get_stages()[stage_id]
2435

25-
def get_benchmark_stages(self):
26-
return dict([(x.id, x) for x in self.benchmark.stages])
36+
def get_stage_by_output(self, output_id: str) -> Optional[omni_schema.Stage]:
37+
"""Get stage that returns output with output_id"""
2738

28-
def get_benchmark_stage(self, stage_id):
29-
stages = self.get_benchmark_stages().values()
30-
return next(stage for stage in stages if stage.id == stage_id)
39+
stage_by_output: dict = {}
40+
for stage_id, stage in self.get_stages().items():
41+
stage_by_output.update({output.id: stage for output in stage.outputs})
42+
43+
return stage_by_output.get(output_id)
44+
45+
def get_modules_by_stage(self, stage: Union[str, omni_schema.Stage]) -> Dict[str, omni_schema.Module]:
46+
"""Get modules by stage/stage_id"""
47+
48+
if isinstance(stage, str):
49+
stage = self.get_stages()[stage]
3150

32-
def get_modules_by_stage(self, stage):
3351
return dict([(x.id, x) for x in stage.modules])
3452

35-
def get_stage_implicit_inputs(self, stage):
53+
def get_stage_implicit_inputs(self, stage: Union[str, omni_schema.Stage]) -> List[str]:
54+
"""Get implicit inputs of a stage by stage/stage_id"""
55+
3656
if isinstance(stage, str):
37-
stage = self.get_benchmark_stages()[stage]
57+
stage = self.get_stages()[stage]
3858

3959
return [input.entries for input in stage.inputs]
4060

41-
def get_inputs_stage(self, implicit_inputs):
42-
stages_map = {key: None for key in implicit_inputs}
43-
if implicit_inputs is not None:
44-
all_stages = self.get_benchmark_stages()
45-
all_stages_outputs = []
46-
for stage_id in all_stages:
47-
outputs = self.get_stage_outputs(stage=stage_id)
48-
outputs = {key: stage_id for key, value in outputs.items()}
49-
all_stages_outputs.append(outputs)
50-
51-
all_stages_outputs = merge_dict_list(all_stages_outputs)
52-
for in_deliverable in implicit_inputs:
53-
# beware stage needs to be substituted
54-
curr_output = all_stages_outputs[in_deliverable]
55-
56-
stages_map[in_deliverable] = curr_output
57-
58-
return stages_map
59-
60-
def get_stage_explicit_inputs(self, implicit_inputs):
61-
explicit = {key: None for key in implicit_inputs}
62-
if implicit_inputs is not None:
63-
all_stages = self.get_benchmark_stages()
64-
all_stages_outputs = []
65-
for stage_id in all_stages:
66-
outputs = self.get_stage_outputs(stage=stage_id)
67-
outputs = {
61+
def get_explicit_inputs(self, input_ids: List[str]) -> Dict[str, str]:
62+
"""Get explicit inputs of a stage by input_id(s)"""
63+
64+
all_stages_outputs = []
65+
for stage_id in self.get_stages():
66+
outputs = self.get_stage_outputs(stage=stage_id)
67+
outputs = {
6868
key: value.format(
6969
input="{input}",
7070
stage=stage_id,
@@ -74,104 +74,89 @@ def get_stage_explicit_inputs(self, implicit_inputs):
7474
)
7575
for key, value in outputs.items()
7676
}
77-
all_stages_outputs.append(outputs)
77+
all_stages_outputs.append(outputs)
7878

79-
all_stages_outputs = merge_dict_list(all_stages_outputs)
80-
for in_deliverable in implicit_inputs:
81-
# beware stage needs to be substituted
82-
curr_output = all_stages_outputs[in_deliverable]
79+
all_stages_outputs = merge_dict_list(all_stages_outputs)
8380

84-
explicit[in_deliverable] = curr_output
81+
explicit = {key: None for key in input_ids}
82+
for in_deliverable in input_ids:
83+
# beware stage needs to be substituted
84+
curr_output = all_stages_outputs[in_deliverable]
85+
86+
explicit[in_deliverable] = curr_output
8587

8688
return explicit
8789

88-
def get_stage_outputs(self, stage):
90+
def get_stage_outputs(self, stage: Union[str, omni_schema.Stage]) -> Dict[str, str]:
91+
"""Get outputs of a stage by stage/stage_id"""
92+
8993
if isinstance(stage, str):
90-
stage = self.get_benchmark_stages()[stage]
94+
stage = self.get_stages()[stage]
9195

9296
return dict([(output.id, output.path) for output in stage.outputs])
9397

94-
def get_module_excludes(self, module):
98+
def get_output_stage(self, output_id: str) -> omni_schema.Stage:
99+
"""Get stage that returns output with out_id"""
100+
101+
stage_by_output: dict = {}
102+
for stage in self.model.stages:
103+
stage_by_output.update({out.id: stage for out in stage.outputs})
104+
105+
return stage_by_output.get(output_id)
106+
107+
def get_module_excludes(self, module: Union[str, omni_schema.Module]) -> List[str]:
108+
"""Get module excludes by module/module_id"""
109+
95110
if isinstance(module, str):
96-
module = self.get_benchmark_modules()[module]
111+
module = self.get_modules()[module]
97112

98113
return module.exclude
99114

100-
def get_module_parameters(self, module):
115+
def get_module_parameters(self, module: Union[str, omni_schema.Module]) -> List[str]:
116+
"""Get module parameters by module/module_id"""
117+
118+
if isinstance(module, str):
119+
module = self.get_modules()[module]
120+
101121
params = None
102122
if module.parameters is not None:
103123
params = [x.values for x in module.parameters]
104124

105125
return params
106126

107-
def get_module_repository(self, module):
127+
def get_module_repository(self, module: Union[str, omni_schema.Module]) -> omni_schema.Repository:
128+
"""Get module repository by module/module_id"""
129+
130+
if isinstance(module, str):
131+
module = self.get_modules()[module]
132+
108133
return module.repository
109134

110-
def is_initial(self, stage):
135+
def is_initial(self, stage: omni_schema.Stage) -> bool:
136+
"""Check if stage is initial"""
137+
111138
if stage.inputs is None or len(stage.inputs) == 0:
112139
return True
113140
else:
114141
return False
115142

116-
def get_after(self, stage):
117-
return stage.after
118-
119-
def get_stage_ids(self):
120-
return [x.id for x in self.benchmark.stages]
143+
def get_outputs(self) -> Dict[str, str]:
144+
"""Get outputs"""
121145

122-
def get_module_ids(self):
123-
module_ids = []
124-
for stage in self.benchmark.stages:
125-
for module in stage.modules:
126-
module_ids.append(module.id)
127-
128-
return module_ids
129-
130-
def get_output_ids(self):
131-
output_ids = []
132-
for stage in self.benchmark.stages:
146+
outputs = {}
147+
for stage_id, stage in self.get_stages().items():
133148
for output in stage.outputs:
134-
output_ids.append(output.id)
149+
outputs[output.id] = output
135150

136-
return output_ids
151+
return outputs
137152

138-
def get_initial_datasets(self):
139-
stages = self.get_benchmark_stages()
140-
for stage_id in stages:
141-
stage = stages[stage_id]
142-
if self.is_initial(stage):
143-
return self.get_modules_by_stage(stage)
153+
def get_modules(self) -> Dict[str, omni_schema.Module]:
154+
"""Get modules"""
144155

145-
def get_initial_stage(self):
146-
stages = self.get_benchmark_stages()
147-
for stage_id in stages:
148-
stage = stages[stage_id]
149-
if self.is_initial(stage):
150-
return stage
151-
152-
def get_benchmark_modules(self):
153156
modules = {}
154-
stages = self.get_benchmark_stages()
155-
for stage_id in stages:
156-
stage = stages[stage_id]
157+
158+
for stage_id, stage in self.get_stages().items():
157159
modules_in_stage = self.get_modules_by_stage(stage)
158160
modules.update(modules_in_stage)
159161

160162
return modules
161-
162-
def stage_order(self, element):
163-
if self.stage_order_map is None:
164-
self.stage_order_map = self._compute_stage_order()
165-
166-
return self.stage_order_map.get(element)
167-
168-
def _compute_stage_order(self):
169-
stages = list(self.get_benchmark_stages().values())
170-
stage_order_map = {
171-
self.get_stage_id(stage): pos for pos, stage in enumerate(stages)
172-
}
173-
# FIXME very rudimentary computation of ordering
174-
# FIXME Might be more complex in future benchmarking scenarios
175-
# Assuming the order in which stages appear in the benchmark YAML is the actual order of the stages during execution
176-
177-
return stage_order_map

src/model/benchmark.py

+9 -9
Original file line number | Diff line number | Diff line change
@@ -14,18 +14,18 @@ def __init__(self, benchmark_yaml: Path, out_dir: str = "out"):
1414

1515
self.converter = converter
1616
self.out_dir = out_dir
17-
self.G = dag.build_dag_from_definition(converter, self.out_dir)
17+
self.G = dag.build_benchmark_dag(converter, self.out_dir)
1818

1919
self.execution_paths = None
2020

2121
def get_converter(self):
2222
return self.converter
2323

2424
def get_benchmark_name(self):
25-
return self.converter.get_benchmark_name()
25+
return self.converter.get_name()
2626

2727
def get_definition(self):
28-
return self.converter.get_benchmark_definition()
28+
return self.converter.get_definition()
2929

3030
def get_definition_file(self):
3131
return self.converter.benchmark_file
@@ -34,7 +34,7 @@ def get_nodes(self):
3434
return list(self.G.nodes)
3535

3636
def get_stage_ids(self):
37-
return self.converter.get_stage_ids()
37+
return self.converter.get_stages().keys()
3838

3939
def get_node_by_id(self, node_id):
4040
for node in self.G.nodes:
@@ -63,22 +63,22 @@ def get_output_paths(self):
6363
return set(output_paths)
6464

6565
def get_explicit_inputs(self, stage_id: str, test: bool = True):
66-
stage = self.converter.get_benchmark_stage(stage_id)
66+
stage = self.converter.get_stage(stage_id)
6767
implicit_inputs = self.converter.get_stage_implicit_inputs(stage)
6868
explicit_inputs = [
69-
self.converter.get_stage_explicit_inputs(i) for i in implicit_inputs
69+
self.converter.get_explicit_inputs(i) for i in implicit_inputs
7070
]
7171
return explicit_inputs
7272

7373
def get_explicit_outputs(self, stage_id: str):
74-
stage = self.converter.get_benchmark_stage(stage_id)
74+
stage = self.converter.get_stage(stage_id)
7575
return self.converter.get_stage_outputs(stage)
7676

7777
def get_available_parameter(self, module_id: str):
7878
node = next(node for node in self.G.nodes if node.module_id == module_id)
7979
return node.get_parameters()
8080

81-
def plot_graph(self):
81+
def plot_benchmark_graph(self):
8282
dag.plot_graph(
8383
self.G, output_file="output_dag.png", scale_factor=1.5, node_spacing=0.2
8484
)
@@ -102,7 +102,7 @@ def _generate_execution_paths(self):
102102

103103
def _get_path_exclusions(self):
104104
path_exclusions = {}
105-
stages = self.converter.get_benchmark_stages()
105+
stages = self.converter.get_stages()
106106
for stage_id in stages:
107107
stage = stages[stage_id]
108108

0 commit comments

Comments
 (0)