@@ -58,7 +58,7 @@ def hashable(s):
5858 except TypeError :
5959 return False
6060
61- def __init__ (self , dsk , low_memory_mode = False ):
61+ def __init__ (self , dsk , low_memory_mode = False , prune_depth = 0 ):
6262 self ._dsk = dsk
6363
6464 # child -> parents. I.e., which parents needs the result of child
@@ -73,9 +73,6 @@ def __init__(self, dsk, low_memory_mode=False):
7373 # key->value of its computation
7474 self ._result_of = {}
7575
76- # child -> nodes that use the child as an input, and that have not been completed
77- self ._pending_parents_of = defaultdict (lambda : set ())
78-
7976 # key->depth. The shallowest level the key is found
8077 self ._depth_of = defaultdict (lambda : float ('inf' ))
8178
@@ -86,6 +83,10 @@ def __init__(self, dsk, low_memory_mode=False):
8683 if low_memory_mode :
8784 self ._flatten_graph ()
8885
86+ self .prune_depth = prune_depth
87+ self .pending_consumers = defaultdict (int )
88+ self .pending_producers = defaultdict (lambda : set ())
89+
8990 self .initialize_graph ()
9091
9192 def left_to_compute (self ):
@@ -103,6 +104,11 @@ def initialize_graph(self):
103104 for key , sexpr in self ._working_graph .items ():
104105 self .set_relations (key , sexpr )
105106
107+ # Then initialize pending consumers if pruning is enabled
108+ if self .prune_depth > 0 :
109+ self ._initialize_pending_consumers ()
110+ self ._initialize_pending_producers ()
111+
106112 def find_dependencies (self , sexpr , depth = 0 ):
107113 dependencies = set ()
108114 if self .graph_keyp (sexpr ):
@@ -123,7 +129,53 @@ def set_relations(self, key, sexpr):
123129
124130 for c in self ._children_of [key ]:
125131 self ._parents_of [c ].add (key )
126- self ._pending_parents_of [c ].add (key )
132+
133+ def _initialize_pending_consumers (self ):
134+ """Initialize pending consumers counts based on prune_depth"""
135+ for key in self ._working_graph :
136+ if key not in self .pending_consumers :
137+ count = 0
138+ # BFS to count consumers up to prune_depth
139+ visited = set ()
140+ queue = [(c , 1 ) for c in self ._parents_of [key ]] # (consumer, depth)
141+
142+ while queue :
143+ consumer , depth = queue .pop (0 )
144+ if depth <= self .prune_depth and consumer not in visited :
145+ visited .add (consumer )
146+ count += 1
147+
148+ # Add next level consumers if we haven't reached max depth
149+ if depth < self .prune_depth :
150+ next_consumers = [(c , depth + 1 ) for c in self ._parents_of [consumer ]]
151+ queue .extend (next_consumers )
152+
153+ self .pending_consumers [key ] = count
154+
155+ def _initialize_pending_producers (self ):
156+ """Initialize pending producers based on prune_depth"""
157+ if self .prune_depth <= 0 :
158+ return
159+
160+ for key in self ._working_graph :
161+ # Use set to store unique producers
162+ producers = set ()
163+ visited = set ()
164+ queue = [(p , 1 ) for p in self ._children_of [key ]] # (producer, depth)
165+
166+ while queue :
167+ producer , depth = queue .pop (0 )
168+ if depth <= self .prune_depth and producer not in visited :
169+ visited .add (producer )
170+ producers .add (producer )
171+
172+ # Add next level producers if we haven't reached max depth
173+ if depth < self .prune_depth :
174+ next_producers = [(p , depth + 1 ) for p in self ._children_of [producer ]]
175+ queue .extend (next_producers )
176+
177+ # Store all producers for this key in pending_producers
178+ self .pending_producers [key ] = producers
127179
128180 def get_ready (self ):
129181 """ List of [(key, sexpr),...] ready for computation.
@@ -148,6 +200,7 @@ def set_result(self, key, value):
148200 of computations that become ready to be executed """
149201 rs = {}
150202 self ._result_of [key ] = value
203+
151204 for p in self ._parents_of [key ]:
152205 self ._missing_of [p ].discard (key )
153206
@@ -164,9 +217,6 @@ def set_result(self, key, value):
164217 else :
165218 rs [p ] = (p , sexpr )
166219
167- for c in self ._children_of [key ]:
168- self ._pending_parents_of [c ].discard (key )
169-
170220 return rs .values ()
171221
172222 def _flatten_graph (self ):
@@ -228,9 +278,6 @@ def get_missing_children(self, key):
228278 def get_parents (self , key ):
229279 return self ._parents_of [key ]
230280
231- def get_pending_parents (self , key ):
232- return self ._pending_parents_of [key ]
233-
234281 def set_targets (self , keys ):
235282 """ Values of keys that need to be computed. """
236283 self ._targets .update (keys )
0 commit comments