Skip to content

Commit 2fc88bc

Browse files
authored
vine: file pruning by depth (#4057)
1 parent f38ba02 commit 2fc88bc

File tree

2 files changed

+68
-25
lines changed

2 files changed

+68
-25
lines changed

taskvine/src/bindings/python3/ndcctools/taskvine/compat/dask_dag.py

+58-11
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def hashable(s):
5858
except TypeError:
5959
return False
6060

61-
def __init__(self, dsk, low_memory_mode=False):
61+
def __init__(self, dsk, low_memory_mode=False, prune_depth=0):
6262
self._dsk = dsk
6363

6464
# child -> parents. I.e., which parents need the result of child
@@ -73,9 +73,6 @@ def __init__(self, dsk, low_memory_mode=False):
7373
# key->value of its computation
7474
self._result_of = {}
7575

76-
# child -> nodes that use the child as an input, and that have not been completed
77-
self._pending_parents_of = defaultdict(lambda: set())
78-
7976
# key->depth. The shallowest level the key is found
8077
self._depth_of = defaultdict(lambda: float('inf'))
8178

@@ -86,6 +83,10 @@ def __init__(self, dsk, low_memory_mode=False):
8683
if low_memory_mode:
8784
self._flatten_graph()
8885

86+
self.prune_depth = prune_depth
87+
self.pending_consumers = defaultdict(int)
88+
self.pending_producers = defaultdict(lambda: set())
89+
8990
self.initialize_graph()
9091

9192
def left_to_compute(self):
@@ -103,6 +104,11 @@ def initialize_graph(self):
103104
for key, sexpr in self._working_graph.items():
104105
self.set_relations(key, sexpr)
105106

107+
# Then initialize pending consumers and pending producers if pruning is enabled
108+
if self.prune_depth > 0:
109+
self._initialize_pending_consumers()
110+
self._initialize_pending_producers()
111+
106112
def find_dependencies(self, sexpr, depth=0):
107113
dependencies = set()
108114
if self.graph_keyp(sexpr):
@@ -123,7 +129,53 @@ def set_relations(self, key, sexpr):
123129

124130
for c in self._children_of[key]:
125131
self._parents_of[c].add(key)
126-
self._pending_parents_of[c].add(key)
132+
133+
def _initialize_pending_consumers(self):
    """Initialize pending consumer counts based on prune_depth.

    For every key in the working graph that does not already have an entry
    in ``self.pending_consumers``, store the number of distinct consumers
    reachable by following ``self._parents_of`` links for at most
    ``self.prune_depth`` levels (direct parents are level 1).

    With ``prune_depth <= 0`` no consumer is within range, so every count
    is 0.
    """
    # Local import so the method does not depend on the module's import block.
    from collections import deque

    for key in self._working_graph:
        # Keep any count that was already set for this key.
        if key in self.pending_consumers:
            continue

        seen = set()
        # Breadth-first walk: entries are (consumer, depth). A deque gives
        # O(1) pops from the left, unlike list.pop(0).
        frontier = deque((consumer, 1) for consumer in self._parents_of[key])

        while frontier:
            consumer, depth = frontier.popleft()
            if depth > self.prune_depth or consumer in seen:
                continue
            seen.add(consumer)

            # Only expand further while still below the depth limit.
            if depth < self.prune_depth:
                frontier.extend(
                    (parent, depth + 1) for parent in self._parents_of[consumer]
                )

        # Each consumer is counted once, however many paths reach it.
        self.pending_consumers[key] = len(seen)
154+
155+
def _initialize_pending_producers(self):
    """Initialize pending producers based on prune_depth.

    For every key in the working graph, store in
    ``self.pending_producers[key]`` the set of distinct producers reachable
    by following ``self._children_of`` links for at most
    ``self.prune_depth`` levels (direct children are level 1). Any existing
    entry for the key is overwritten.

    Does nothing when ``prune_depth <= 0``.
    """
    if self.prune_depth <= 0:
        return

    # Local import so the method does not depend on the module's import block.
    from collections import deque

    for key in self._working_graph:
        # This set is both the result and the visited marker of the walk.
        producers = set()
        # Breadth-first walk: entries are (producer, depth). A deque gives
        # O(1) pops from the left, unlike list.pop(0).
        frontier = deque((producer, 1) for producer in self._children_of[key])

        while frontier:
            producer, depth = frontier.popleft()
            if depth > self.prune_depth or producer in producers:
                continue
            producers.add(producer)

            # Only expand further while still below the depth limit.
            if depth < self.prune_depth:
                frontier.extend(
                    (child, depth + 1) for child in self._children_of[producer]
                )

        # Store all producers for this key in pending_producers.
        self.pending_producers[key] = producers
127179

128180
def get_ready(self):
129181
""" List of [(key, sexpr),...] ready for computation.
@@ -148,6 +200,7 @@ def set_result(self, key, value):
148200
of computations that become ready to be executed """
149201
rs = {}
150202
self._result_of[key] = value
203+
151204
for p in self._parents_of[key]:
152205
self._missing_of[p].discard(key)
153206

@@ -164,9 +217,6 @@ def set_result(self, key, value):
164217
else:
165218
rs[p] = (p, sexpr)
166219

167-
for c in self._children_of[key]:
168-
self._pending_parents_of[c].discard(key)
169-
170220
return rs.values()
171221

172222
def _flatten_graph(self):
@@ -228,9 +278,6 @@ def get_missing_children(self, key):
228278
def get_parents(self, key):
229279
return self._parents_of[key]
230280

231-
def get_pending_parents(self, key):
232-
return self._pending_parents_of[key]
233-
234281
def set_targets(self, keys):
235282
""" Values of keys that need to be computed. """
236283
self._targets.update(keys)

taskvine/src/bindings/python3/ndcctools/taskvine/compat/dask_executor.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class DaskVine(Manager):
107107
# fn(*args) at some point during its execution to produce the dask task result.
108108
# Should return a tuple of (wrapper result, dask call result). Use for debugging.
109109
# @param wrapper_proc Function to process results from wrapper on completion. (default is print)
110-
# @param prune_files If True, remove files from the cluster after they are no longer needed.
110+
# @param prune_depth Control pruning behavior: 0 (default) - no pruning, 1 - only check direct consumers, 2+ - check consumers up to specified depth
111111
def get(self, dsk, keys, *,
112112
environment=None,
113113
extra_files=None,
@@ -132,7 +132,7 @@ def get(self, dsk, keys, *,
132132
progress_label="[green]tasks",
133133
wrapper=None,
134134
wrapper_proc=print,
135-
prune_files=False,
135+
prune_depth=0,
136136
hoisting_modules=None, # Deprecated, use lib_modules
137137
import_modules=None, # Deprecated, use lib_modules
138138
lazy_transfers=True, # Deprecated, use worker_tranfers
@@ -174,7 +174,7 @@ def get(self, dsk, keys, *,
174174
self.progress_label = progress_label
175175
self.wrapper = wrapper
176176
self.wrapper_proc = wrapper_proc
177-
self.prune_files = prune_files
177+
self.prune_depth = prune_depth
178178
self.category_info = defaultdict(lambda: {"num_tasks": 0, "total_execution_time": 0})
179179
self.max_priority = float('inf')
180180
self.min_priority = float('-inf')
@@ -212,7 +212,7 @@ def _dask_execute(self, dsk, keys):
212212
indices = {k: inds for (k, inds) in find_dask_keys(keys)}
213213
keys_flatten = indices.keys()
214214

215-
dag = DaskVineDag(dsk, low_memory_mode=self.low_memory_mode)
215+
dag = DaskVineDag(dsk, low_memory_mode=self.low_memory_mode, prune_depth=self.prune_depth)
216216
tag = f"dag-{id(dag)}"
217217

218218
# create Library if using 'function-calls' task mode.
@@ -294,8 +294,12 @@ def _dask_execute(self, dsk, keys):
294294
if t.key in dsk:
295295
bar_update(advance=1)
296296

297-
if self.prune_files:
298-
self._prune_file(dag, t.key)
297+
if self.prune_depth > 0:
298+
for p in dag.pending_producers[t.key]:
299+
dag.pending_consumers[p] -= 1
300+
if dag.pending_consumers[p] == 0:
301+
p_result = dag.get_result(p)
302+
self.prune_file(p_result._file)
299303
else:
300304
retries_left = t.decrement_retry()
301305
print(f"task id {t.id} key {t.key} failed: {t.result}. {retries_left} attempts left.\n{t.std_output}")
@@ -446,14 +450,6 @@ def _fill_key_result(self, dag, key):
446450
return raw.load()
447451
else:
448452
return raw
449-
450-
def _prune_file(self, dag, key):
451-
children = dag.get_children(key)
452-
for c in children:
453-
if len(dag.get_pending_parents(c)) == 0:
454-
c_result = dag.get_result(c)
455-
self.prune_file(c_result._file)
456-
457453
##
458454
# @class ndcctools.taskvine.dask_executor.DaskVineFile
459455
#

0 commit comments

Comments
 (0)