1
1
import os
2
+ from pathlib import Path
3
+ from typing import Union , Optional , Dict , List
4
+
5
+ from omni_schema .datamodel import omni_schema
2
6
3
7
from src .utils .helpers import merge_dict_list , load_yaml
4
8
5
9
6
10
class LinkMLConverter :
7
11
8
- def __init__ (self , benchmark_file ):
9
- self .stage_order_map = None
12
+ def __init__ (self , benchmark_file : Path ):
10
13
self .benchmark_file = os .path .abspath (benchmark_file )
11
- self .benchmark = load_yaml (benchmark_file )
14
+ self .model = load_yaml (benchmark_file )
15
+
16
+ def get_name (self ) -> str :
17
+ """Get name of the benchmark"""
18
+
19
+ return self .model .name if self .model .name else self .model .id
20
+
21
+ def get_definition (self ) -> omni_schema .Benchmark :
22
+ """Get underlying benchmark"""
23
+
24
+ return self .model
12
25
13
- def get_benchmark_name (self ):
14
- return self . benchmark . name if self . benchmark . name else self . benchmark . id
26
+ def get_stages (self ) -> Dict [ str , omni_schema . Stage ] :
27
+ """Get benchmark stages"""
15
28
16
- def get_benchmark_definition (self ):
17
- return self .benchmark
29
+ return dict ([(x .id , x ) for x in self .model .stages ])
18
30
19
- def get_stage_id (self , stage ) :
20
- return stage . id
31
+ def get_stage (self , stage_id : str ) -> Optional [ omni_schema . Stage ] :
32
+ """Get stage by stage_id"""
21
33
22
- def get_module_id (self , module ):
23
- return module .id
34
+ return self .get_stages ()[stage_id ]
24
35
25
- def get_benchmark_stages (self ) :
26
- return dict ([( x . id , x ) for x in self . benchmark . stages ])
36
+ def get_stage_by_output (self , output_id : str ) -> Optional [ omni_schema . Stage ] :
37
+ """Get stage that returns output with output_id"""
27
38
28
- def get_benchmark_stage (self , stage_id ):
29
- stages = self .get_benchmark_stages ().values ()
30
- return next (stage for stage in stages if stage .id == stage_id )
39
+ stage_by_output : dict = {}
40
+ for stage_id , stage in self .get_stages ().items ():
41
+ stage_by_output .update ({output .id : stage for output in stage .outputs })
42
+
43
+ return stage_by_output .get (output_id )
44
+
45
+ def get_modules_by_stage (self , stage : Union [str , omni_schema .Stage ]) -> Dict [str , omni_schema .Module ]:
46
+ """Get modules by stage/stage_id"""
47
+
48
+ if isinstance (stage , str ):
49
+ stage = self .get_stages ()[stage ]
31
50
32
- def get_modules_by_stage (self , stage ):
33
51
return dict ([(x .id , x ) for x in stage .modules ])
34
52
35
- def get_stage_implicit_inputs (self , stage ):
53
+ def get_stage_implicit_inputs (self , stage : Union [str , omni_schema .Stage ]) -> List [str ]:
54
+ """Get implicit inputs of a stage by stage/stage_id"""
55
+
36
56
if isinstance (stage , str ):
37
- stage = self .get_benchmark_stages ()[stage ]
57
+ stage = self .get_stages ()[stage ]
38
58
39
59
return [input .entries for input in stage .inputs ]
40
60
41
- def get_inputs_stage (self , implicit_inputs ):
42
- stages_map = {key : None for key in implicit_inputs }
43
- if implicit_inputs is not None :
44
- all_stages = self .get_benchmark_stages ()
45
- all_stages_outputs = []
46
- for stage_id in all_stages :
47
- outputs = self .get_stage_outputs (stage = stage_id )
48
- outputs = {key : stage_id for key , value in outputs .items ()}
49
- all_stages_outputs .append (outputs )
50
-
51
- all_stages_outputs = merge_dict_list (all_stages_outputs )
52
- for in_deliverable in implicit_inputs :
53
- # beware stage needs to be substituted
54
- curr_output = all_stages_outputs [in_deliverable ]
55
-
56
- stages_map [in_deliverable ] = curr_output
57
-
58
- return stages_map
59
-
60
- def get_stage_explicit_inputs (self , implicit_inputs ):
61
- explicit = {key : None for key in implicit_inputs }
62
- if implicit_inputs is not None :
63
- all_stages = self .get_benchmark_stages ()
64
- all_stages_outputs = []
65
- for stage_id in all_stages :
66
- outputs = self .get_stage_outputs (stage = stage_id )
67
- outputs = {
61
+ def get_explicit_inputs (self , input_ids : List [str ]) -> Dict [str , str ]:
62
+ """Get explicit inputs of a stage by input_id(s)"""
63
+
64
+ all_stages_outputs = []
65
+ for stage_id in self .get_stages ():
66
+ outputs = self .get_stage_outputs (stage = stage_id )
67
+ outputs = {
68
68
key : value .format (
69
69
input = "{input}" ,
70
70
stage = stage_id ,
@@ -74,104 +74,89 @@ def get_stage_explicit_inputs(self, implicit_inputs):
74
74
)
75
75
for key , value in outputs .items ()
76
76
}
77
- all_stages_outputs .append (outputs )
77
+ all_stages_outputs .append (outputs )
78
78
79
- all_stages_outputs = merge_dict_list (all_stages_outputs )
80
- for in_deliverable in implicit_inputs :
81
- # beware stage needs to be substituted
82
- curr_output = all_stages_outputs [in_deliverable ]
79
+ all_stages_outputs = merge_dict_list (all_stages_outputs )
83
80
84
- explicit [in_deliverable ] = curr_output
81
+ explicit = {key : None for key in input_ids }
82
+ for in_deliverable in input_ids :
83
+ # beware stage needs to be substituted
84
+ curr_output = all_stages_outputs [in_deliverable ]
85
+
86
+ explicit [in_deliverable ] = curr_output
85
87
86
88
return explicit
87
89
88
- def get_stage_outputs (self , stage ):
90
+ def get_stage_outputs (self , stage : Union [str , omni_schema .Stage ]) -> Dict [str , str ]:
91
+ """Get outputs of a stage by stage/stage_id"""
92
+
89
93
if isinstance (stage , str ):
90
- stage = self .get_benchmark_stages ()[stage ]
94
+ stage = self .get_stages ()[stage ]
91
95
92
96
return dict ([(output .id , output .path ) for output in stage .outputs ])
93
97
94
- def get_module_excludes (self , module ):
98
+ def get_output_stage (self , output_id : str ) -> omni_schema .Stage :
99
+ """Get stage that returns output with out_id"""
100
+
101
+ stage_by_output : dict = {}
102
+ for stage in self .model .stages :
103
+ stage_by_output .update ({out .id : stage for out in stage .outputs })
104
+
105
+ return stage_by_output .get (output_id )
106
+
107
+ def get_module_excludes (self , module : Union [str , omni_schema .Module ]) -> List [str ]:
108
+ """Get module excludes by module/module_id"""
109
+
95
110
if isinstance (module , str ):
96
- module = self .get_benchmark_modules ()[module ]
111
+ module = self .get_modules ()[module ]
97
112
98
113
return module .exclude
99
114
100
- def get_module_parameters (self , module ):
115
+ def get_module_parameters (self , module : Union [str , omni_schema .Module ]) -> List [str ]:
116
+ """Get module parameters by module/module_id"""
117
+
118
+ if isinstance (module , str ):
119
+ module = self .get_modules ()[module ]
120
+
101
121
params = None
102
122
if module .parameters is not None :
103
123
params = [x .values for x in module .parameters ]
104
124
105
125
return params
106
126
107
- def get_module_repository (self , module ):
127
+ def get_module_repository (self , module : Union [str , omni_schema .Module ]) -> omni_schema .Repository :
128
+ """Get module repository by module/module_id"""
129
+
130
+ if isinstance (module , str ):
131
+ module = self .get_modules ()[module ]
132
+
108
133
return module .repository
109
134
110
- def is_initial (self , stage ):
135
+ def is_initial (self , stage : omni_schema .Stage ) -> bool :
136
+ """Check if stage is initial"""
137
+
111
138
if stage .inputs is None or len (stage .inputs ) == 0 :
112
139
return True
113
140
else :
114
141
return False
115
142
116
- def get_after (self , stage ):
117
- return stage .after
118
-
119
- def get_stage_ids (self ):
120
- return [x .id for x in self .benchmark .stages ]
143
+ def get_outputs (self ) -> Dict [str , str ]:
144
+ """Get outputs"""
121
145
122
- def get_module_ids (self ):
123
- module_ids = []
124
- for stage in self .benchmark .stages :
125
- for module in stage .modules :
126
- module_ids .append (module .id )
127
-
128
- return module_ids
129
-
130
- def get_output_ids (self ):
131
- output_ids = []
132
- for stage in self .benchmark .stages :
146
+ outputs = {}
147
+ for stage_id , stage in self .get_stages ().items ():
133
148
for output in stage .outputs :
134
- output_ids . append ( output .id )
149
+ outputs [ output .id ] = output
135
150
136
- return output_ids
151
+ return outputs
137
152
138
- def get_initial_datasets (self ):
139
- stages = self .get_benchmark_stages ()
140
- for stage_id in stages :
141
- stage = stages [stage_id ]
142
- if self .is_initial (stage ):
143
- return self .get_modules_by_stage (stage )
153
+ def get_modules (self ) -> Dict [str , omni_schema .Module ]:
154
+ """Get modules"""
144
155
145
- def get_initial_stage (self ):
146
- stages = self .get_benchmark_stages ()
147
- for stage_id in stages :
148
- stage = stages [stage_id ]
149
- if self .is_initial (stage ):
150
- return stage
151
-
152
- def get_benchmark_modules (self ):
153
156
modules = {}
154
- stages = self .get_benchmark_stages ()
155
- for stage_id in stages :
156
- stage = stages [stage_id ]
157
+
158
+ for stage_id , stage in self .get_stages ().items ():
157
159
modules_in_stage = self .get_modules_by_stage (stage )
158
160
modules .update (modules_in_stage )
159
161
160
162
return modules
161
-
162
- def stage_order (self , element ):
163
- if self .stage_order_map is None :
164
- self .stage_order_map = self ._compute_stage_order ()
165
-
166
- return self .stage_order_map .get (element )
167
-
168
- def _compute_stage_order (self ):
169
- stages = list (self .get_benchmark_stages ().values ())
170
- stage_order_map = {
171
- self .get_stage_id (stage ): pos for pos , stage in enumerate (stages )
172
- }
173
- # FIXME very rudimentary computation of ordering
174
- # FIXME Might be more complex in future benchmarking scenarios
175
- # Assuming the order in which stages appear in the benchmark YAML is the actual order of the stages during execution
176
-
177
- return stage_order_map
0 commit comments