diff --git a/mindsdb_sql/parser/ast/select/case.py b/mindsdb_sql/parser/ast/select/case.py index a4e6fc1..02fa827 100644 --- a/mindsdb_sql/parser/ast/select/case.py +++ b/mindsdb_sql/parser/ast/select/case.py @@ -4,7 +4,7 @@ class Case(ASTNode): - def __init__(self, rules, default, *args, **kwargs): + def __init__(self, rules, default=None, *args, **kwargs): super().__init__(*args, **kwargs) # structure: @@ -32,10 +32,13 @@ def to_tree(self, *args, level=0, **kwargs): f'{ind1}{condition.to_string()} => {result.to_string()}' ) rules_str = '\n'.join(rules_ar) + default_str = '' + if self.default is not None: + default_str = f'{ind1}default => {self.default.to_string()}\n' return f'{ind}Case(\n' \ f'{rules_str}\n' \ - f'{ind1}default => {self.default.to_string()}\n' \ + f'{default_str}' \ f'{ind})' def get_string(self, *args, alias=True, **kwargs): @@ -47,4 +50,7 @@ def get_string(self, *args, alias=True, **kwargs): ) rules_str = ' '.join(rules_ar) - return f"CASE {rules_str} ELSE {self.default.to_string()} END" + default_str = '' + if self.default is not None: + default_str = f' ELSE {self.default.to_string()}' + return f"CASE {rules_str}{default_str} END" diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 6624212..9e73355 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1329,9 +1329,10 @@ def column_list(self, p): return column_list # case - @_('CASE case_conditions ELSE expr END') + @_('CASE case_conditions ELSE expr END', + 'CASE case_conditions END') def case(self, p): - return Case(rules=p.case_conditions, default=p.expr) + return Case(rules=p.case_conditions, default=getattr(p, 'expr', None)) @_('case_condition', 'case_conditions case_condition') diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index df08a26..c0d0c8a 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -279,12 +279,26 @@ def to_expression(self, t): elif isinstance(t, ast.NotExists): sub_stmt = self.prepare_select(t.query) col = ~sub_stmt.exists() + elif isinstance(t, ast.Case): + col = self.prepare_case(t) else: # some other complex object? raise NotImplementedError(f'Column {t}') return col + def prepare_case(self, t: ast.Case): + conditions = [] + for condition, result in t.rules: + conditions.append( + (self.to_expression(condition), self.to_expression(result)) + ) + else_ = None + if t.default is not None: + else_ = self.to_expression(t.default) + + return sa.case(conditions, else_=else_) + def to_function(self, t): op = getattr(sa.func, t.op) if t.from_arg is not None: diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 181281b..8f9df64 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -971,7 +971,6 @@ def test_case(self): sum( CASE WHEN 1 = 1 THEN 1 - ELSE 0 END ) FROM INFORMATION_SCHEMA.COLLATIONS''' @@ -1009,7 +1008,6 @@ def test_case(self): Constant(1) ], ], - default=Constant(0) ) ] )