Skip to content

Commit 5f9cda5

Browse files
committed
Remove unnecessary Simplifier() instantiations in simplify.py
1 parent 4c9c765 commit 5f9cda5

File tree

3 files changed

+59
-63
lines changed

3 files changed

+59
-63
lines changed

sqlglot/optimizer/annotate_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def annotate_types(
5656
expression_metadata: Maps expression type to corresponding annotation function.
5757
coerces_to: Maps expression type to set of types that it can be coerced into.
5858
overwrite_types: Re-annotate the existing AST types.
59+
5960
Returns:
6061
The expression annotated with types.
6162
"""

sqlglot/optimizer/normalize.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlglot.errors import OptimizeError
77
from sqlglot.helper import while_changing
88
from sqlglot.optimizer.scope import find_all_in_scope
9-
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
9+
from sqlglot.optimizer.simplify import Simplifier, flatten
1010

1111
logger = logging.getLogger("sqlglot")
1212

@@ -28,14 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
2828
Returns:
2929
sqlglot.Expression: normalized expression
3030
"""
31+
simplifier = Simplifier()
32+
3133
for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
3234
if isinstance(node, exp.Connector):
3335
if normalized(node, dnf=dnf):
3436
continue
3537
root = node is expression
3638
original = node.copy()
3739

38-
node.transform(rewrite_between, copy=False)
40+
node.transform(simplifier.rewrite_between, copy=False)
3941
distance = normalization_distance(node, dnf=dnf, max_=max_distance)
4042

4143
if distance > max_distance:
@@ -46,7 +48,10 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
4648

4749
try:
4850
node = node.replace(
49-
while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
51+
while_changing(
52+
node,
53+
lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
54+
)
5055
)
5156
except OptimizeError as e:
5257
logger.info(e)
@@ -146,7 +151,7 @@ def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
146151
yield from _predicate_lengths(right, dnf, max_, depth)
147152

148153

149-
def distributive_law(expression, dnf, max_distance):
154+
def distributive_law(expression, dnf, max_distance, simplifier=None):
150155
"""
151156
x OR (y AND z) -> (x OR y) AND (x OR z)
152157
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -168,32 +173,34 @@ def distributive_law(expression, dnf, max_distance):
168173
from_func = exp.and_ if from_exp == exp.And else exp.or_
169174
to_func = exp.and_ if to_exp == exp.And else exp.or_
170175

176+
simplifier = simplifier or Simplifier()
177+
171178
if isinstance(a, to_exp) and isinstance(b, to_exp):
172179
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
173-
return _distribute(a, b, from_func, to_func)
174-
return _distribute(b, a, from_func, to_func)
180+
return _distribute(a, b, from_func, to_func, simplifier)
181+
return _distribute(b, a, from_func, to_func, simplifier)
175182
if isinstance(a, to_exp):
176-
return _distribute(b, a, from_func, to_func)
183+
return _distribute(b, a, from_func, to_func, simplifier)
177184
if isinstance(b, to_exp):
178-
return _distribute(a, b, from_func, to_func)
185+
return _distribute(a, b, from_func, to_func, simplifier)
179186

180187
return expression
181188

182189

183-
def _distribute(a, b, from_func, to_func):
190+
def _distribute(a, b, from_func, to_func, simplifier):
184191
if isinstance(a, exp.Connector):
185192
exp.replace_children(
186193
a,
187194
lambda c: to_func(
188-
uniq_sort(flatten(from_func(c, b.left))),
189-
uniq_sort(flatten(from_func(c, b.right))),
195+
simplifier.uniq_sort(flatten(from_func(c, b.left))),
196+
simplifier.uniq_sort(flatten(from_func(c, b.right))),
190197
copy=False,
191198
),
192199
)
193200
else:
194201
a = to_func(
195-
uniq_sort(flatten(from_func(a, b.left))),
196-
uniq_sort(flatten(from_func(a, b.right))),
202+
simplifier.uniq_sort(flatten(from_func(a, b.left))),
203+
simplifier.uniq_sort(flatten(from_func(a, b.right))),
197204
copy=False,
198205
)
199206

sqlglot/optimizer/simplify.py

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,47 @@ def flatten(expression):
108108
return expression
109109

110110

111-
# Backward compatibility wrappers for functions used by other modules
112-
def rewrite_between(expression: exp.Expression) -> exp.Expression:
113-
return Simplifier().rewrite_between(expression)
111+
def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
112+
if not isinstance(expression, exp.Paren):
113+
return expression
114114

115+
this = expression.this
116+
parent = expression.parent
117+
parent_is_predicate = isinstance(parent, exp.Predicate)
115118

116-
def uniq_sort(expression: exp.Expression, root: bool = True) -> exp.Expression:
117-
return Simplifier().uniq_sort(expression, root)
119+
if isinstance(this, exp.Select):
120+
return expression
118121

122+
if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
123+
return expression
119124

120-
def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
121-
return Simplifier(dialect=dialect).simplify_parens(expression, dialect)
125+
# Handle risingwave struct columns
126+
# see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
127+
if (
128+
dialect == "risingwave"
129+
and isinstance(parent, exp.Dot)
130+
and (isinstance(parent.right, (exp.Identifier, exp.Star)))
131+
):
132+
return expression
133+
134+
if (
135+
not isinstance(parent, (exp.Condition, exp.Binary))
136+
or isinstance(parent, exp.Paren)
137+
or (
138+
not isinstance(this, exp.Binary)
139+
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
140+
)
141+
or (
142+
isinstance(this, exp.Predicate)
143+
and not (parent_is_predicate or isinstance(parent, exp.Neg))
144+
)
145+
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
146+
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
147+
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
148+
):
149+
return this
150+
151+
return expression
122152

123153

124154
def propagate_constants(expression, root=True):
@@ -594,7 +624,7 @@ def _simplify(expression):
594624

595625
new_node = self.simplify_literals(new_node, root)
596626
new_node = self.simplify_equality(new_node)
597-
new_node = self.simplify_parens(new_node, dialect=self.dialect)
627+
new_node = simplify_parens(new_node, dialect=self.dialect)
598628
new_node = self.simplify_datetrunc(new_node)
599629
new_node = self.sort_comparison(new_node)
600630
new_node = self.simplify_startswith(new_node)
@@ -1067,48 +1097,6 @@ def _simplify_binary(self, expression, a, b):
10671097

10681098
return None
10691099

1070-
def simplify_parens(self, expression: exp.Expression, dialect: DialectType) -> exp.Expression:
1071-
if not isinstance(expression, exp.Paren):
1072-
return expression
1073-
1074-
this = expression.this
1075-
parent = expression.parent
1076-
parent_is_predicate = isinstance(parent, exp.Predicate)
1077-
1078-
if isinstance(this, exp.Select):
1079-
return expression
1080-
1081-
if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
1082-
return expression
1083-
1084-
# Handle risingwave struct columns
1085-
# see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
1086-
if (
1087-
dialect == "risingwave"
1088-
and isinstance(parent, exp.Dot)
1089-
and (isinstance(parent.right, (exp.Identifier, exp.Star)))
1090-
):
1091-
return expression
1092-
1093-
if (
1094-
not isinstance(parent, (exp.Condition, exp.Binary))
1095-
or isinstance(parent, exp.Paren)
1096-
or (
1097-
not isinstance(this, exp.Binary)
1098-
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
1099-
)
1100-
or (
1101-
isinstance(this, exp.Predicate)
1102-
and not (parent_is_predicate or isinstance(parent, exp.Neg))
1103-
)
1104-
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
1105-
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
1106-
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
1107-
):
1108-
return this
1109-
1110-
return expression
1111-
11121100
@annotate_types_on_change
11131101
def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
11141102
# COALESCE(x) -> x

0 commit comments

Comments
 (0)