18
18
from cubed .runtime .types import Callback , CubedPipeline , DagExecutor , TaskEndEvent
19
19
from cubed .runtime .utils import (
20
20
execution_stats ,
21
+ execution_timing ,
21
22
handle_callbacks ,
22
23
handle_operation_start_callbacks ,
23
24
profile_memray ,
@@ -61,9 +62,14 @@ def execute_dag(
61
62
[callback .on_task_end (event ) for callback in callbacks ]
62
63
63
64
65
+ @execution_timing
66
+ def run_func_threads (input , func = None , config = None , name = None , compute_id = None ):
67
+ return func (input , config = config )
68
+
69
+
64
70
@profile_memray
65
71
@execution_stats
66
- def run_func (input , func = None , config = None , name = None , compute_id = None ):
72
+ def run_func_processes (input , func = None , config = None , name = None , compute_id = None ):
67
73
return func (input , config = config )
68
74
69
75
@@ -142,7 +148,11 @@ def create_futures_func_multiprocessing(input, **kwargs):
142
148
143
149
144
150
def pipeline_to_stream (
145
- concurrent_executor : Executor , name : str , pipeline : CubedPipeline , ** kwargs
151
+ concurrent_executor : Executor ,
152
+ run_func : Callable ,
153
+ name : str ,
154
+ pipeline : CubedPipeline ,
155
+ ** kwargs ,
146
156
) -> Stream :
147
157
return stream .iterate (
148
158
map_unordered (
@@ -200,15 +210,17 @@ async def async_execute_dag(
200
210
mp_context = context ,
201
211
max_tasks_per_child = max_tasks_per_child ,
202
212
)
213
+ run_func = run_func_processes
203
214
else :
204
215
concurrent_executor = ThreadPoolExecutor (max_workers = max_workers )
216
+ run_func = run_func_threads
205
217
try :
206
218
if not compute_arrays_in_parallel :
207
219
# run one pipeline at a time
208
220
for name , node in visit_nodes (dag , resume = resume ):
209
221
handle_operation_start_callbacks (callbacks , name )
210
222
st = pipeline_to_stream (
211
- concurrent_executor , name , node ["pipeline" ], ** kwargs
223
+ concurrent_executor , run_func , name , node ["pipeline" ], ** kwargs
212
224
)
213
225
async with st .stream () as streamer :
214
226
async for _ , stats in streamer :
@@ -218,7 +230,7 @@ async def async_execute_dag(
218
230
# run pipelines in the same topological generation in parallel by merging their streams
219
231
streams = [
220
232
pipeline_to_stream (
221
- concurrent_executor , name , node ["pipeline" ], ** kwargs
233
+ concurrent_executor , run_func , name , node ["pipeline" ], ** kwargs
222
234
)
223
235
for name , node in gen
224
236
]
0 commit comments