Commit a918396

Don't measure peak mem with threads executor
1 parent f9d5952 commit a918396
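
Why the change: under the threads executor every task runs inside the same process, so a per-task peak-memory reading (as captured for the processes executor via @execution_stats and memray) would reflect the combined allocations of all concurrently running threads, not the task being measured. Below is a minimal standalone sketch of the problem using only the Unix-only stdlib resource module; the task function and its 50 MB allocation are hypothetical, not Cubed code.

import resource
from concurrent.futures import ThreadPoolExecutor

def task(i):
    buf = bytearray(50 * 1024 * 1024)  # hypothetical 50 MB allocation
    # ru_maxrss is a process-wide high-water mark (kilobytes on Linux,
    # bytes on macOS): every thread sees the combined peak of the whole
    # interpreter, not the footprint of its own task.
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return i, len(buf), peak

with ThreadPoolExecutor(max_workers=4) as executor:
    for i, size, peak in executor.map(task, range(4)):
        print(f"task {i}: allocated {size} bytes, process-wide peak RSS {peak}")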

File tree

cubed/runtime/executors/local.py
cubed/runtime/utils.py

2 files changed: +37 -4 lines changed

cubed/runtime/executors/local.py

Lines changed: 16 additions & 4 deletions

@@ -18,6 +18,7 @@
 from cubed.runtime.types import Callback, CubedPipeline, DagExecutor, TaskEndEvent
 from cubed.runtime.utils import (
     execution_stats,
+    execution_timing,
     handle_callbacks,
     handle_operation_start_callbacks,
     profile_memray,
@@ -61,9 +62,14 @@ def execute_dag(
         [callback.on_task_end(event) for callback in callbacks]
 
 
+@execution_timing
+def run_func_threads(input, func=None, config=None, name=None, compute_id=None):
+    return func(input, config=config)
+
+
 @profile_memray
 @execution_stats
-def run_func(input, func=None, config=None, name=None, compute_id=None):
+def run_func_processes(input, func=None, config=None, name=None, compute_id=None):
     return func(input, config=config)
 
 
@@ -142,7 +148,11 @@ def create_futures_func_multiprocessing(input, **kwargs):
 
 
 def pipeline_to_stream(
-    concurrent_executor: Executor, name: str, pipeline: CubedPipeline, **kwargs
+    concurrent_executor: Executor,
+    run_func: Callable,
+    name: str,
+    pipeline: CubedPipeline,
+    **kwargs,
 ) -> Stream:
     return stream.iterate(
         map_unordered(
@@ -200,15 +210,17 @@ async def async_execute_dag(
             mp_context=context,
             max_tasks_per_child=max_tasks_per_child,
         )
+        run_func = run_func_processes
     else:
         concurrent_executor = ThreadPoolExecutor(max_workers=max_workers)
+        run_func = run_func_threads
     try:
         if not compute_arrays_in_parallel:
             # run one pipeline at a time
             for name, node in visit_nodes(dag, resume=resume):
                 handle_operation_start_callbacks(callbacks, name)
                 st = pipeline_to_stream(
-                    concurrent_executor, name, node["pipeline"], **kwargs
+                    concurrent_executor, run_func, name, node["pipeline"], **kwargs
                 )
                 async with st.stream() as streamer:
                     async for _, stats in streamer:
@@ -218,7 +230,7 @@
             # run pipelines in the same topological generation in parallel by merging their streams
             streams = [
                 pipeline_to_stream(
-                    concurrent_executor, name, node["pipeline"], **kwargs
+                    concurrent_executor, run_func, name, node["pipeline"], **kwargs
                 )
                 for name, node in gen
             ]
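
A brief usage sketch of the split above, assuming a Cubed build containing this commit and the module path shown in this diff; double is a hypothetical payload function. The timing-only wrapper returns the function's result together with a stats dict of wall-clock timestamps, while run_func_processes additionally records peak memory and, when memray is installed, memory profiles.

from cubed.runtime.executors.local import run_func_threads

def double(x, config=None):
    # Hypothetical payload; Cubed passes the task's config through.
    return x * 2

result, stats = run_func_threads(21, func=double)
print(result)  # 42
print(stats)   # {'function_start_tstamp': ..., 'function_end_tstamp': ...}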

cubed/runtime/utils.py

Lines changed: 21 additions & 0 deletions

@@ -40,12 +40,33 @@ def execute_with_stats(function, *args, **kwargs):
     )
 
 
+def execute_with_timing(function, *args, **kwargs):
+    """Invoke function and measure timing information.
+
+    Returns the result of the function call and a stats dictionary.
+    """
+
+    function_start_tstamp = time.time()
+    result = function(*args, **kwargs)
+    function_end_tstamp = time.time()
+    return result, dict(
+        function_start_tstamp=function_start_tstamp,
+        function_end_tstamp=function_end_tstamp,
+    )
+
+
 def execution_stats(func):
     """Decorator to measure timing information and peak memory usage of a function call."""
 
     return partial(execute_with_stats, func)
 
 
+def execution_timing(func):
+    """Decorator to measure timing information of a function call."""
+
+    return partial(execute_with_timing, func)
+
+
 def execute_with_memray(function, input, **kwargs):
     # only run memray if installed, and only for first input (for operations that run on block locations)
     if (
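
To make the contract concrete, here is a self-contained replica of the two helpers added above (reproduced from the diff apart from comments, so the snippet runs without Cubed; slow_add is a hypothetical example function):

import time
from functools import partial

def execute_with_timing(function, *args, **kwargs):
    # Record wall-clock timestamps around the call; no memory tracking.
    function_start_tstamp = time.time()
    result = function(*args, **kwargs)
    function_end_tstamp = time.time()
    return result, dict(
        function_start_tstamp=function_start_tstamp,
        function_end_tstamp=function_end_tstamp,
    )

def execution_timing(func):
    # partial(...) binds func as the first argument of execute_with_timing.
    return partial(execute_with_timing, func)

@execution_timing
def slow_add(a, b):
    time.sleep(0.1)
    return a + b

result, stats = slow_add(1, 2)
assert result == 3
assert stats["function_end_tstamp"] >= stats["function_start_tstamp"]

Note that the decorator changes the call's return type: callers unpack (result, stats) rather than the bare result, which matches the "async for _, stats in streamer" unpacking visible in the local.py diff.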
