Skip to content

Commit

Permalink
Merge pull request #404 from mindsdb/fix-case
Browse files Browse the repository at this point in the history
Fix case statement
  • Loading branch information
ea-rus authored Sep 20, 2024
2 parents 18fbd6b + 4b6ff7a commit 0f7a688
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 7 deletions.
12 changes: 9 additions & 3 deletions mindsdb_sql/parser/ast/select/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Case(ASTNode):
def __init__(self, rules, default, *args, **kwargs):
def __init__(self, rules, default=None, *args, **kwargs):
super().__init__(*args, **kwargs)

# structure:
Expand Down Expand Up @@ -32,10 +32,13 @@ def to_tree(self, *args, level=0, **kwargs):
f'{ind1}{condition.to_string()} => {result.to_string()}'
)
rules_str = '\n'.join(rules_ar)
default_str = ''
if self.default is not None:
default_str = f'{ind1}default => {self.default.to_string()}\n'

return f'{ind}Case(\n' \
f'{rules_str}\n' \
f'{ind1}default => {self.default.to_string()}\n' \
f'{default_str}' \
f'{ind})'

def get_string(self, *args, alias=True, **kwargs):
Expand All @@ -47,4 +50,7 @@ def get_string(self, *args, alias=True, **kwargs):
)
rules_str = ' '.join(rules_ar)

return f"CASE {rules_str} ELSE {self.default.to_string()} END"
default_str = ''
if self.default is not None:
default_str = f' ELSE {self.default.to_string()}'
return f"CASE {rules_str}{default_str} END"
5 changes: 3 additions & 2 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,9 +1329,10 @@ def column_list(self, p):
return column_list

# case
@_('CASE case_conditions ELSE expr END')
@_('CASE case_conditions ELSE expr END',
'CASE case_conditions END')
def case(self, p):
return Case(rules=p.case_conditions, default=p.expr)
return Case(rules=p.case_conditions, default=getattr(p, 'expr', None))

@_('case_condition',
'case_conditions case_condition')
Expand Down
14 changes: 14 additions & 0 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,26 @@ def to_expression(self, t):
elif isinstance(t, ast.NotExists):
sub_stmt = self.prepare_select(t.query)
col = ~sub_stmt.exists()
elif isinstance(t, ast.Case):
col = self.prepare_case(t)
else:
# some other complex object?
raise NotImplementedError(f'Column {t}')

return col

def prepare_case(self, t: ast.Case):
conditions = []
for condition, result in t.rules:
conditions.append(
(self.to_expression(condition), self.to_expression(result))
)
else_ = None
if t.default is not None:
else_ = self.to_expression(t.default)

return sa.case(conditions, else_=else_)

def to_function(self, t):
op = getattr(sa.func, t.op)
if t.from_arg is not None:
Expand Down
2 changes: 0 additions & 2 deletions tests/test_parser/test_base_sql/test_select_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,6 @@ def test_case(self):
sum(
CASE
WHEN 1 = 1 THEN 1
ELSE 0
END
)
FROM INFORMATION_SCHEMA.COLLATIONS'''
Expand Down Expand Up @@ -1009,7 +1008,6 @@ def test_case(self):
Constant(1)
],
],
default=Constant(0)
)
]
)
Expand Down

0 comments on commit 0f7a688

Please sign in to comment.