Skip to content

Commit c39a968

Browse files
committed
fix new expression type
1 parent 5f9cda5 commit c39a968

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

sqlglot/optimizer/normalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
2828
Returns:
2929
sqlglot.Expression: normalized expression
3030
"""
31-
simplifier = Simplifier()
31+
simplifier = Simplifier(annotate_new_expressions=False)
3232

3333
for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
3434
if isinstance(node, exp.Connector):
@@ -173,7 +173,7 @@ def distributive_law(expression, dnf, max_distance, simplifier=None):
173173
from_func = exp.and_ if from_exp == exp.And else exp.or_
174174
to_func = exp.and_ if to_exp == exp.And else exp.or_
175175

176-
simplifier = simplifier or Simplifier()
176+
simplifier = simplifier or Simplifier(annotate_new_expressions=False)
177177

178178
if isinstance(a, to_exp) and isinstance(b, to_exp):
179179
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):

sqlglot/optimizer/simplify.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,19 @@ def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.E
8585
if new_expression is None:
8686
return new_expression
8787

88-
if new_expression != expression:
88+
if self.annotate_new_expressions and expression != new_expression:
8989
self._annotator.clear()
90+
91+
# We annotate this to ensure new children nodes are also annotated
9092
new_expression = self._annotator.annotate(
91-
expression=new_expression, annotate_scope=False
93+
expression=new_expression,
94+
annotate_scope=False,
9295
)
96+
97+
# Whatever expression the original expression is transformed into needs to preserve
98+
# the original type, otherwise the simplification could result in a different schema
99+
new_expression.type = expression.type
100+
93101
return new_expression
94102

95103
return _func
@@ -452,11 +460,9 @@ def boolean_literal(condition):
452460

453461

454462
class Simplifier:
455-
def __init__(
456-
self,
457-
dialect: DialectType = None,
458-
):
463+
def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
459464
self.dialect = Dialect.get_or_raise(dialect)
465+
self.annotate_new_expressions = annotate_new_expressions
460466

461467
self._annotator: TypeAnnotator = TypeAnnotator(
462468
schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False

0 commit comments

Comments
 (0)