13
13
# limitations under the License.
14
14
15
15
import weakref
16
- from typing import NamedTuple , Optional
16
+ from abc import ABC
17
+ from typing import NamedTuple , Optional , Type , Set
17
18
18
19
import numpy as np
19
20
from pandas .api .types import is_scalar
20
21
21
22
from .... import dataframe as md
22
- from ....core import Tileable , get_output_types , ENTITY_TYPE
23
+ from ....core import Tileable , get_output_types , ENTITY_TYPE , TileableGraph
24
+ from ....core .graph import EntityGraph
23
25
from ....dataframe .arithmetic .core import DataFrameUnaryUfunc , DataFrameBinopUfunc
24
26
from ....dataframe .base .eval import DataFrameEval
25
27
from ....dataframe .indexing .getitem import DataFrameIndex
26
28
from ....dataframe .indexing .setitem import DataFrameSetitem
27
- from ....typing import OperandType
29
+ from ....typing import OperandType , EntityType
28
30
from ....utils import implements
29
- from ..core import OptimizationRecord , OptimizationRecordType
31
+ from ..core import OptimizationRecord , OptimizationRecordType , OptimizationRecords
30
32
from ..tileable .core import register_operand_based_optimization_rule
31
33
from .core import OperandBasedOptimizationRule
32
34
@@ -66,8 +68,70 @@ def builder(lhs: str, rhs: str):
66
68
_extract_result_cache = weakref .WeakKeyDictionary ()
67
69
68
70
71
class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC):
    """Shared base for rules that rewrite a node into a single new node.

    Subclasses call :meth:`_mark_predecessor` while analysing a node to note
    which predecessors it consumes, then :meth:`_replace_with_new_node` to
    swap the node (and any predecessors that are no longer needed) for the
    rewritten one.
    """

    def __init__(
        self,
        graph: EntityGraph,
        records: OptimizationRecords,
        optimizer_cls: Type["Optimizer"],
    ):
        super().__init__(graph, records, optimizer_cls)
        # Maps a predecessor's *original* entity to the set of nodes that
        # were marked as consuming it.
        self._marked_predecessors = dict()

    def _mark_predecessor(self, node: EntityType, predecessor: EntityType):
        """Note that ``node`` consumes ``predecessor``.

        The entry is keyed by the original entity of ``predecessor`` as
        resolved through the optimization records (falling back to
        ``predecessor`` itself).
        """
        pred_original = self._records.get_original_entity(predecessor, predecessor)
        # Test and write through the same key. Checking the raw
        # ``predecessor`` while storing under ``pred_original`` would reset
        # the set (or raise KeyError) whenever the two entities differ.
        self._marked_predecessors.setdefault(pred_original, set()).add(node)

    def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]:
        """Collect ``node`` plus the predecessors that can be dropped with it.

        A predecessor is dropped when neither its original nor its optimized
        entity is a graph result, and every one of its successors in the
        graph is among the (optimized) nodes marked against it. A ``delete``
        record is appended for each dropped predecessor.

        Returns:
            The set of graph nodes to remove, always including ``node``
            (in its optimized form when the records hold one).
        """
        node = self._records.get_optimization_result(node) or node
        removed_nodes = {node}
        results_set = set(self._graph.results)
        removed_pairs = []
        for pred in self._graph.iter_predecessors(node):
            pred_original = self._records.get_original_entity(pred, pred)
            pred_opt = self._records.get_optimization_result(pred, pred)

            # Results must stay in the graph.
            if pred_opt in results_set or pred_original in results_set:
                continue

            affect_succ = self._marked_predecessors.get(pred_original) or ()
            affect_succ_opt = {
                self._records.get_optimization_result(s, s) for s in affect_succ
            }
            if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)):
                removed_pairs.append((pred_original, pred_opt))

        # Records are appended only after the scan above has finished, so the
        # record lookups made during the scan see the records as they were on
        # entry.
        for pred_original, pred_opt in removed_pairs:
            removed_nodes.add(pred_opt)
            self._records.append_record(
                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
            )
        return removed_nodes

    def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType):
        """Replace ``original_node`` with ``new_node`` and record the rewrite.

        The nodes returned by :meth:`_find_nodes_to_remove` are swapped for a
        one-node subgraph holding ``new_node``, and a ``replace`` record is
        appended from the original entity of ``original_node`` to
        ``new_node``.
        """
        # Find all the nodes to remove
        nodes_to_remove = self._find_nodes_to_remove(original_node)

        # Build the replaced subgraph
        subgraph = TileableGraph()
        subgraph.add_node(new_node)

        new_results = [new_node] if new_node in self._graph.results else None
        self._replace_subgraph(subgraph, nodes_to_remove, new_results)
        self._records.append_record(
            OptimizationRecord(
                self._records.get_original_entity(original_node, original_node),
                new_node,
                OptimizationRecordType.replace,
            )
        )
131
+
132
+
69
133
@register_operand_based_optimization_rule ([DataFrameUnaryUfunc , DataFrameBinopUfunc ])
70
- class SeriesArithmeticToEval (OperandBasedOptimizationRule ):
134
+ class SeriesArithmeticToEval (_EvalRewriteOptimizationRule ):
71
135
_var_counter = 0
72
136
73
137
@classmethod
@@ -151,7 +215,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord:
151
215
if in_tileable is None :
152
216
return EvalExtractRecord ()
153
217
154
- self ._add_collapsable_predecessor (tileable , op .inputs [0 ])
218
+ self ._mark_predecessor (tileable , op .inputs [0 ])
155
219
return EvalExtractRecord (
156
220
in_tileable , _func_name_to_builder [func_name ](expr ), variables
157
221
)
@@ -164,10 +228,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord:
164
228
165
229
lhs_tileable , lhs_expr , lhs_vars = self ._extract_eval_expression (op .lhs )
166
230
if lhs_tileable is not None :
167
- self ._add_collapsable_predecessor (tileable , op .lhs )
231
+ self ._mark_predecessor (tileable , op .lhs )
168
232
rhs_tileable , rhs_expr , rhs_vars = self ._extract_eval_expression (op .rhs )
169
233
if rhs_tileable is not None :
170
- self ._add_collapsable_predecessor (tileable , op .rhs )
234
+ self ._mark_predecessor (tileable , op .rhs )
171
235
172
236
if lhs_expr is None or rhs_expr is None :
173
237
return EvalExtractRecord ()
@@ -204,24 +268,10 @@ def apply_to_operand(self, op: OperandType):
204
268
new_node = new_op .new_tileable (
205
269
[opt_in_tileable ], _key = node .key , _id = node .id , ** node .params
206
270
).data
271
+ self ._replace_with_new_node (node , new_node )
207
272
208
- self ._remove_collapsable_predecessors (node )
209
- self ._replace_node (node , new_node )
210
- self ._graph .add_edge (opt_in_tileable , new_node )
211
273
212
- self ._records .append_record (
213
- OptimizationRecord (node , new_node , OptimizationRecordType .replace )
214
- )
215
-
216
- # check node if it's in result
217
- try :
218
- i = self ._graph .results .index (node )
219
- self ._graph .results [i ] = new_node
220
- except ValueError :
221
- pass
222
-
223
-
224
- class _DataFrameEvalRewriteRule (OperandBasedOptimizationRule ):
274
+ class _DataFrameEvalRewriteRule (_EvalRewriteOptimizationRule ):
225
275
@implements (OperandBasedOptimizationRule .match_operand )
226
276
def match_operand (self , op : OperandType ) -> bool :
227
277
optimized_eval_op = self ._get_optimized_eval_op (op )
@@ -245,16 +295,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
245
295
def _get_input_columnar_node (self , op : OperandType ) -> ENTITY_TYPE :
246
296
raise NotImplementedError
247
297
248
- def _update_op_node (self , old_node : ENTITY_TYPE , new_node : ENTITY_TYPE ):
249
- self ._replace_node (old_node , new_node )
250
- for in_tileable in new_node .inputs :
251
- self ._graph .add_edge (in_tileable , new_node )
252
-
253
- original_node = self ._records .get_original_entity (old_node , old_node )
254
- self ._records .append_record (
255
- OptimizationRecord (original_node , new_node , OptimizationRecordType .replace )
256
- )
257
-
258
298
@implements (OperandBasedOptimizationRule .apply_to_operand )
259
299
def apply_to_operand (self , op : DataFrameIndex ):
260
300
node = op .outputs [0 ]
@@ -268,10 +308,8 @@ def apply_to_operand(self, op: DataFrameIndex):
268
308
new_node = new_op .new_tileable (
269
309
[opt_in_tileable ], _key = node .key , _id = node .id , ** node .params
270
310
).data
271
-
272
- self ._add_collapsable_predecessor (node , in_columnar_node )
273
- self ._remove_collapsable_predecessors (node )
274
- self ._update_op_node (node , new_node )
311
+ self ._mark_predecessor (node , in_columnar_node )
312
+ self ._replace_with_new_node (node , new_node )
275
313
276
314
277
315
@register_operand_based_optimization_rule ([DataFrameIndex ])
@@ -360,7 +398,5 @@ def apply_to_operand(self, op: DataFrameIndex):
360
398
new_node = new_op .new_tileable (
361
399
pred_opt_node .inputs , _key = node .key , _id = node .id , ** node .params
362
400
).data
363
-
364
- self ._add_collapsable_predecessor (opt_node , pred_opt_node )
365
- self ._remove_collapsable_predecessors (opt_node )
366
- self ._update_op_node (opt_node , new_node )
401
+ self ._mark_predecessor (opt_node , pred_opt_node )
402
+ self ._replace_with_new_node (opt_node , new_node )
0 commit comments