@@ -58,7 +58,7 @@ def hashable(s):
58
58
except TypeError :
59
59
return False
60
60
61
- def __init__ (self , dsk , low_memory_mode = False ):
61
+ def __init__ (self , dsk , low_memory_mode = False , prune_depth = 0 ):
62
62
self ._dsk = dsk
63
63
64
64
# child -> parents, i.e., which parents need the result of child
@@ -73,9 +73,6 @@ def __init__(self, dsk, low_memory_mode=False):
73
73
# key->value of its computation
74
74
self ._result_of = {}
75
75
76
- # child -> nodes that use the child as an input, and that have not been completed
77
- self ._pending_parents_of = defaultdict (lambda : set ())
78
-
79
76
# key->depth. The shallowest level the key is found
80
77
self ._depth_of = defaultdict (lambda : float ('inf' ))
81
78
@@ -86,6 +83,10 @@ def __init__(self, dsk, low_memory_mode=False):
86
83
if low_memory_mode :
87
84
self ._flatten_graph ()
88
85
86
+ self .prune_depth = prune_depth
87
+ self .pending_consumers = defaultdict (int )
88
+ self .pending_producers = defaultdict (lambda : set ())
89
+
89
90
self .initialize_graph ()
90
91
91
92
def left_to_compute (self ):
@@ -103,6 +104,11 @@ def initialize_graph(self):
103
104
for key , sexpr in self ._working_graph .items ():
104
105
self .set_relations (key , sexpr )
105
106
107
+ # Then initialize pending consumers if pruning is enabled
108
+ if self .prune_depth > 0 :
109
+ self ._initialize_pending_consumers ()
110
+ self ._initialize_pending_producers ()
111
+
106
112
def find_dependencies (self , sexpr , depth = 0 ):
107
113
dependencies = set ()
108
114
if self .graph_keyp (sexpr ):
@@ -123,7 +129,53 @@ def set_relations(self, key, sexpr):
123
129
124
130
for c in self ._children_of [key ]:
125
131
self ._parents_of [c ].add (key )
126
- self ._pending_parents_of [c ].add (key )
132
+
133
def _initialize_pending_consumers(self):
    """Initialize pending consumer counts based on prune_depth.

    For every key in the working graph that does not yet have an entry in
    ``self.pending_consumers``, store the number of distinct keys reachable
    by following ``self._parents_of`` links for at most ``self.prune_depth``
    levels (the direct parents of a key are level 1). Keys that already
    have an entry are left untouched.
    """
    max_depth = self.prune_depth
    for key in self._working_graph:
        if key in self.pending_consumers:
            continue

        # Level-by-level BFS: `frontier` holds the consumers discovered at
        # the current depth that have not been counted yet. Processing a
        # whole level at a time avoids the O(n) cost of list.pop(0).
        consumers = set()
        frontier = set(self._parents_of[key])
        depth = 1
        while frontier and depth <= max_depth:
            consumers |= frontier
            if depth >= max_depth:
                # Deepest allowed level reached; do not expand further.
                break
            frontier = {
                parent
                for consumer in frontier
                for parent in self._parents_of[consumer]
            } - consumers
            depth += 1

        self.pending_consumers[key] = len(consumers)
154
+
155
def _initialize_pending_producers(self):
    """Initialize pending producers based on prune_depth.

    Does nothing when ``self.prune_depth`` is not positive. Otherwise, for
    every key in the working graph, ``self.pending_producers[key]`` is set
    to the set of distinct keys reachable by following
    ``self._children_of`` links for at most ``self.prune_depth`` levels
    (the direct children of a key are level 1).
    """
    max_depth = self.prune_depth
    if max_depth <= 0:
        return

    for key in self._working_graph:
        # Level-by-level BFS: `frontier` holds the producers discovered at
        # the current depth that have not been recorded yet. A single set
        # serves as both the result and the visited marker, and processing
        # whole levels avoids the O(n) cost of list.pop(0).
        producers = set()
        frontier = set(self._children_of[key])
        depth = 1
        while frontier and depth <= max_depth:
            producers |= frontier
            if depth >= max_depth:
                # Deepest allowed level reached; do not expand further.
                break
            frontier = {
                child
                for producer in frontier
                for child in self._children_of[producer]
            } - producers
            depth += 1

        # Store all producers found for this key
        self.pending_producers[key] = producers
127
179
128
180
def get_ready (self ):
129
181
""" List of [(key, sexpr),...] ready for computation.
@@ -148,6 +200,7 @@ def set_result(self, key, value):
148
200
of computations that become ready to be executed """
149
201
rs = {}
150
202
self ._result_of [key ] = value
203
+
151
204
for p in self ._parents_of [key ]:
152
205
self ._missing_of [p ].discard (key )
153
206
@@ -164,9 +217,6 @@ def set_result(self, key, value):
164
217
else :
165
218
rs [p ] = (p , sexpr )
166
219
167
- for c in self ._children_of [key ]:
168
- self ._pending_parents_of [c ].discard (key )
169
-
170
220
return rs .values ()
171
221
172
222
def _flatten_graph (self ):
@@ -228,9 +278,6 @@ def get_missing_children(self, key):
228
278
def get_parents(self, key):
    """Return the parents recorded for `key`, i.e. the keys that need its result."""
    parents = self._parents_of[key]
    return parents
230
280
231
- def get_pending_parents (self , key ):
232
- return self ._pending_parents_of [key ]
233
-
234
281
def set_targets (self , keys ):
235
282
""" Values of keys that need to be computed. """
236
283
self ._targets .update (keys )
0 commit comments