Skip to content

Commit 0f7a688

Browse files
authored
Merge pull request #404 from mindsdb/fix-case
Fix case statement
2 parents 18fbd6b + 4b6ff7a commit 0f7a688

File tree

4 files changed

+26
-7
lines changed

4 files changed

+26
-7
lines changed

mindsdb_sql/parser/ast/select/case.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
class Case(ASTNode):
7-
def __init__(self, rules, default, *args, **kwargs):
7+
def __init__(self, rules, default=None, *args, **kwargs):
88
super().__init__(*args, **kwargs)
99

1010
# structure:
@@ -32,10 +32,13 @@ def to_tree(self, *args, level=0, **kwargs):
3232
f'{ind1}{condition.to_string()} => {result.to_string()}'
3333
)
3434
rules_str = '\n'.join(rules_ar)
35+
default_str = ''
36+
if self.default is not None:
37+
default_str = f'{ind1}default => {self.default.to_string()}\n'
3538

3639
return f'{ind}Case(\n' \
3740
f'{rules_str}\n' \
38-
f'{ind1}default => {self.default.to_string()}\n' \
41+
f'{default_str}' \
3942
f'{ind})'
4043

4144
def get_string(self, *args, alias=True, **kwargs):
@@ -47,4 +50,7 @@ def get_string(self, *args, alias=True, **kwargs):
4750
)
4851
rules_str = ' '.join(rules_ar)
4952

50-
return f"CASE {rules_str} ELSE {self.default.to_string()} END"
53+
default_str = ''
54+
if self.default is not None:
55+
default_str = f' ELSE {self.default.to_string()}'
56+
return f"CASE {rules_str}{default_str} END"

mindsdb_sql/parser/dialects/mindsdb/parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,9 +1329,10 @@ def column_list(self, p):
13291329
return column_list
13301330

13311331
# case
1332-
@_('CASE case_conditions ELSE expr END')
1332+
@_('CASE case_conditions ELSE expr END',
1333+
'CASE case_conditions END')
13331334
def case(self, p):
1334-
return Case(rules=p.case_conditions, default=p.expr)
1335+
return Case(rules=p.case_conditions, default=getattr(p, 'expr', None))
13351336

13361337
@_('case_condition',
13371338
'case_conditions case_condition')

mindsdb_sql/render/sqlalchemy_render.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,26 @@ def to_expression(self, t):
279279
elif isinstance(t, ast.NotExists):
280280
sub_stmt = self.prepare_select(t.query)
281281
col = ~sub_stmt.exists()
282+
elif isinstance(t, ast.Case):
283+
col = self.prepare_case(t)
282284
else:
283285
# some other complex object?
284286
raise NotImplementedError(f'Column {t}')
285287

286288
return col
287289

290+
def prepare_case(self, t: ast.Case):
291+
conditions = []
292+
for condition, result in t.rules:
293+
conditions.append(
294+
(self.to_expression(condition), self.to_expression(result))
295+
)
296+
else_ = None
297+
if t.default is not None:
298+
else_ = self.to_expression(t.default)
299+
300+
return sa.case(conditions, else_=else_)
301+
288302
def to_function(self, t):
289303
op = getattr(sa.func, t.op)
290304
if t.from_arg is not None:

tests/test_parser/test_base_sql/test_select_structure.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,6 @@ def test_case(self):
971971
sum(
972972
CASE
973973
WHEN 1 = 1 THEN 1
974-
ELSE 0
975974
END
976975
)
977976
FROM INFORMATION_SCHEMA.COLLATIONS'''
@@ -1009,7 +1008,6 @@ def test_case(self):
10091008
Constant(1)
10101009
],
10111010
],
1012-
default=Constant(0)
10131011
)
10141012
]
10151013
)

0 commit comments

Comments
 (0)