Skip to content

Commit e060abf

Browse files
committed
Refine ArithmeticToEval related rules
1 parent dd02ba9 commit e060abf

File tree

3 files changed

+79
-78
lines changed

3 files changed

+79
-78
lines changed

mars/optimization/logical/core.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
import functools
1515
import itertools
16-
import weakref
1716
from abc import ABC, abstractmethod
1817
from collections import defaultdict
1918
from dataclasses import dataclass
@@ -92,8 +91,6 @@ def get_original_entity(
9291

9392

9493
class OptimizationRule(ABC):
95-
_preds_to_remove = weakref.WeakKeyDictionary()
96-
9794
def __init__(
9895
self,
9996
graph: EntityGraph,
@@ -217,35 +214,6 @@ def _replace_subgraph(
217214
for result in new_results:
218215
self._graph.results[result_indices[result.key]] = result
219216

220-
def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
    """
    Add ``node`` to the set kept in ``_preds_to_remove`` for ``predecessor``.

    The set is stored under the original entity of ``predecessor`` as
    resolved through the optimization records (``predecessor`` itself is
    passed as the fallback value of that lookup).

    Parameters
    ----------
    node : EntityType
        Node to add to the set.
    predecessor : EntityType
        Predecessor of ``node`` whose original entity is used as the key.
    """
    pred_original = self._records.get_original_entity(predecessor, predecessor)
    # Look up and store under the same key. Testing ``predecessor`` while
    # storing under ``pred_original`` would replace an existing set (and
    # lose the nodes recorded earlier) whenever the two entities differ.
    self._preds_to_remove.setdefault(pred_original, set()).add(node)
226-
227-
def _remove_collapsable_predecessors(self, node: EntityType):
    """
    Remove predecessors of ``node`` whose successors were all registered in
    ``_preds_to_remove``, and append a ``delete`` record for each of them.

    Predecessors that appear in the graph results, either as their original
    or as their optimized entity, are left untouched.

    Parameters
    ----------
    node : EntityType
        Node whose predecessors are inspected; its optimization result is
        used instead when the records hold one.
    """
    records = self._records
    graph = self._graph
    node = records.get_optimization_result(node) or node

    # First pass: only collect, so the graph is not mutated while its
    # predecessors are being walked.
    collapsed = []
    for pred in graph.predecessors(node):
        original = records.get_original_entity(pred, pred)
        optimized = records.get_optimization_result(pred, pred)

        if optimized in graph.results or original in graph.results:
            continue

        registered = self._preds_to_remove.get(original) or []
        registered_opt = [records.get_optimization_result(s, s) for s in registered]
        # Skip the predecessor as soon as one successor was not registered.
        if any(succ not in registered_opt for succ in graph.successors(pred)):
            continue
        collapsed.append((original, optimized))

    # Second pass: drop the collected nodes and record the deletions.
    for original, optimized in collapsed:
        graph.remove_node(optimized)
        deletion = OptimizationRecord(original, None, OptimizationRecordType.delete)
        records.append_record(deletion)
248-
249217

250218
class OperandBasedOptimizationRule(OptimizationRule):
251219
"""

mars/optimization/logical/tests/test_core.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,7 @@ def test_replace_null_subgraph():
157157

158158
c1.inputs.clear()
159159
c2.inputs.clear()
160-
r.replace_subgraph(
161-
None,
162-
{key_to_node[op.key] for op in [s1, s2]}
163-
)
160+
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
164161
assert g1.results == expected_results
165162
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
166163
expected_edges = {

mars/optimization/logical/tileable/arithmetic_query.py

Lines changed: 78 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,22 @@
1313
# limitations under the License.
1414

1515
import weakref
16-
from typing import NamedTuple, Optional
16+
from abc import ABC
17+
from typing import NamedTuple, Optional, Type, Set
1718

1819
import numpy as np
1920
from pandas.api.types import is_scalar
2021

2122
from .... import dataframe as md
22-
from ....core import Tileable, get_output_types, ENTITY_TYPE
23+
from ....core import Tileable, get_output_types, ENTITY_TYPE, TileableGraph
24+
from ....core.graph import EntityGraph
2325
from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc
2426
from ....dataframe.base.eval import DataFrameEval
2527
from ....dataframe.indexing.getitem import DataFrameIndex
2628
from ....dataframe.indexing.setitem import DataFrameSetitem
27-
from ....typing import OperandType
29+
from ....typing import OperandType, EntityType
2830
from ....utils import implements
29-
from ..core import OptimizationRecord, OptimizationRecordType
31+
from ..core import OptimizationRecord, OptimizationRecordType, OptimizationRecords
3032
from ..tileable.core import register_operand_based_optimization_rule
3133
from .core import OperandBasedOptimizationRule
3234

@@ -66,8 +68,70 @@ def builder(lhs: str, rhs: str):
6668
_extract_result_cache = weakref.WeakKeyDictionary()
6769

6870

71+
class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC):
    """
    Shared helpers for rules that replace a node, together with some of its
    marked predecessors, by one newly built node.

    Subclasses call :meth:`_mark_predecessor` while analysing a node and
    :meth:`_replace_with_new_node` once the replacement node is built.
    """

    def __init__(
        self,
        graph: EntityGraph,
        records: OptimizationRecords,
        optimizer_cls: Type["Optimizer"],
    ):
        super().__init__(graph, records, optimizer_cls)
        # Maps the original entity of a predecessor to the set of nodes
        # that marked it through ``_mark_predecessor``.
        self._marked_predecessors = dict()

    def _mark_predecessor(self, node: EntityType, predecessor: EntityType):
        """
        Add ``node`` to the set of nodes that marked ``predecessor``.

        The set is stored under the original entity of ``predecessor`` as
        resolved through the optimization records (``predecessor`` itself
        is passed as the fallback value of that lookup).
        """
        pred_original = self._records.get_original_entity(predecessor, predecessor)
        # Look up and store under the same key. Testing ``predecessor``
        # while storing under ``pred_original`` would replace an existing
        # set (losing earlier marks) whenever the two entities differ.
        self._marked_predecessors.setdefault(pred_original, set()).add(node)

    def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]:
        """
        Collect ``node`` plus those predecessors that can go away with it.

        A predecessor is collected when it is not among the graph results
        (neither as original nor as optimized entity) and every one of its
        successors in the graph is the optimization result of a node that
        marked it. A ``delete`` record is appended for each collected
        predecessor.

        Returns
        -------
        Set[EntityType]
            The (optimized) node itself and the collected predecessors.
        """
        node = self._records.get_optimization_result(node) or node
        removed_nodes = {node}
        results_set = set(self._graph.results)
        removed_pairs = []
        for pred in self._graph.iter_predecessors(node):
            pred_original = self._records.get_original_entity(pred, pred)
            pred_opt = self._records.get_optimization_result(pred, pred)

            if pred_opt in results_set or pred_original in results_set:
                continue

            affect_succ = self._marked_predecessors.get(pred_original) or []
            affect_succ_opt = [
                self._records.get_optimization_result(s, s) for s in affect_succ
            ]
            if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)):
                removed_pairs.append((pred_original, pred_opt))

        # Records are appended only after the walk over the predecessors.
        for pred_original, pred_opt in removed_pairs:
            removed_nodes.add(pred_opt)
            self._records.append_record(
                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
            )
        return removed_nodes

    def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType):
        """
        Replace ``original_node`` and its removable predecessors with
        ``new_node`` and append a ``replace`` record for it.
        """
        # Find all the nodes to remove
        nodes_to_remove = self._find_nodes_to_remove(original_node)

        # Build the replaced subgraph
        subgraph = TileableGraph()
        subgraph.add_node(new_node)

        # Only hand over new results when the new node is found among the
        # current graph results.
        new_results = [new_node] if new_node in self._graph.results else None
        self._replace_subgraph(subgraph, nodes_to_remove, new_results)
        self._records.append_record(
            OptimizationRecord(
                self._records.get_original_entity(original_node, original_node),
                new_node,
                OptimizationRecordType.replace,
            )
        )
131+
132+
69133
@register_operand_based_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc])
70-
class SeriesArithmeticToEval(OperandBasedOptimizationRule):
134+
class SeriesArithmeticToEval(_EvalRewriteOptimizationRule):
71135
_var_counter = 0
72136

73137
@classmethod
@@ -151,7 +215,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord:
151215
if in_tileable is None:
152216
return EvalExtractRecord()
153217

154-
self._add_collapsable_predecessor(tileable, op.inputs[0])
218+
self._mark_predecessor(tileable, op.inputs[0])
155219
return EvalExtractRecord(
156220
in_tileable, _func_name_to_builder[func_name](expr), variables
157221
)
@@ -164,10 +228,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord:
164228

165229
lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs)
166230
if lhs_tileable is not None:
167-
self._add_collapsable_predecessor(tileable, op.lhs)
231+
self._mark_predecessor(tileable, op.lhs)
168232
rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs)
169233
if rhs_tileable is not None:
170-
self._add_collapsable_predecessor(tileable, op.rhs)
234+
self._mark_predecessor(tileable, op.rhs)
171235

172236
if lhs_expr is None or rhs_expr is None:
173237
return EvalExtractRecord()
@@ -204,24 +268,10 @@ def apply_to_operand(self, op: OperandType):
204268
new_node = new_op.new_tileable(
205269
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
206270
).data
271+
self._replace_with_new_node(node, new_node)
207272

208-
self._remove_collapsable_predecessors(node)
209-
self._replace_node(node, new_node)
210-
self._graph.add_edge(opt_in_tileable, new_node)
211273

212-
self._records.append_record(
213-
OptimizationRecord(node, new_node, OptimizationRecordType.replace)
214-
)
215-
216-
# check node if it's in result
217-
try:
218-
i = self._graph.results.index(node)
219-
self._graph.results[i] = new_node
220-
except ValueError:
221-
pass
222-
223-
224-
class _DataFrameEvalRewriteRule(OperandBasedOptimizationRule):
274+
class _DataFrameEvalRewriteRule(_EvalRewriteOptimizationRule):
225275
@implements(OperandBasedOptimizationRule.match_operand)
226276
def match_operand(self, op: OperandType) -> bool:
227277
optimized_eval_op = self._get_optimized_eval_op(op)
@@ -245,16 +295,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
245295
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
246296
raise NotImplementedError
247297

248-
def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE):
    """
    Swap ``old_node`` for ``new_node`` in the graph, connect every input of
    ``new_node`` to it, and append a ``replace`` record that maps the
    original entity of ``old_node`` to ``new_node``.
    """
    graph = self._graph
    graph.replace_node = None if False else None  # placeholder removed below
257-
258298
@implements(OperandBasedOptimizationRule.apply_to_operand)
259299
def apply_to_operand(self, op: DataFrameIndex):
260300
node = op.outputs[0]
@@ -268,10 +308,8 @@ def apply_to_operand(self, op: DataFrameIndex):
268308
new_node = new_op.new_tileable(
269309
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
270310
).data
271-
272-
self._add_collapsable_predecessor(node, in_columnar_node)
273-
self._remove_collapsable_predecessors(node)
274-
self._update_op_node(node, new_node)
311+
self._mark_predecessor(node, in_columnar_node)
312+
self._replace_with_new_node(node, new_node)
275313

276314

277315
@register_operand_based_optimization_rule([DataFrameIndex])
@@ -360,7 +398,5 @@ def apply_to_operand(self, op: DataFrameIndex):
360398
new_node = new_op.new_tileable(
361399
pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params
362400
).data
363-
364-
self._add_collapsable_predecessor(opt_node, pred_opt_node)
365-
self._remove_collapsable_predecessors(opt_node)
366-
self._update_op_node(opt_node, new_node)
401+
self._mark_predecessor(opt_node, pred_opt_node)
402+
self._replace_with_new_node(opt_node, new_node)

0 commit comments

Comments
 (0)