diff --git a/mindsdb_sql/parser/ast/select/case.py b/mindsdb_sql/parser/ast/select/case.py index 02fa827..1ae0df0 100644 --- a/mindsdb_sql/parser/ast/select/case.py +++ b/mindsdb_sql/parser/ast/select/case.py @@ -4,13 +4,14 @@ class Case(ASTNode): - def __init__(self, rules, default=None, *args, **kwargs): + def __init__(self, rules, default=None, arg=None, *args, **kwargs): super().__init__(*args, **kwargs) # structure: # [ # [ condition, result ] # ] + self.arg = arg self.rules = rules self.default = default @@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs): if self.default is not None: default_str = f'{ind1}default => {self.default.to_string()}\n' + arg_str = '' + if self.arg is not None: + arg_str = f'{ind1}arg => {self.arg.to_string()}\n' + return f'{ind}Case(\n' \ + f'{arg_str}'\ f'{rules_str}\n' \ f'{default_str}' \ f'{ind})' @@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs): default_str = '' if self.default is not None: default_str = f' ELSE {self.default.to_string()}' - return f"CASE {rules_str}{default_str} END" + + arg_str = '' + if self.arg is not None: + arg_str = f'{self.arg.to_string()} ' + return f"CASE {arg_str}{rules_str}{default_str} END" diff --git a/mindsdb_sql/parser/ast/select/operation.py b/mindsdb_sql/parser/ast/select/operation.py index d7cad6a..e208d43 100644 --- a/mindsdb_sql/parser/ast/select/operation.py +++ b/mindsdb_sql/parser/ast/select/operation.py @@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs): class WindowFunction(ASTNode): - def __init__(self, function, partition=None, order_by=None, alias=None): + def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None): super().__init__() self.function = function self.partition = partition self.order_by = order_by self.alias = alias + self.modifier = modifier def to_tree(self, *args, level=0, **kwargs): fnc_str = self.function.to_tree(level=level+2) @@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs): alias_str = self.alias.to_string() else: alias_str = '' - return f'{fnc_str} over({partition_str} {order_str}) {alias_str}' + modifier_str = ' ' + self.modifier if self.modifier else '' + return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}' class Object(ASTNode): @@ -177,7 +179,12 @@ def __init__(self, info): super().__init__(op='interval', args=[info, ]) def get_string(self, *args, **kwargs): - return f'INTERVAL {self.args[0]}' + + arg = self.args[0] + items = arg.split(' ', maxsplit=1) + # quote first element + items[0] = f"'{items[0]}'" + return "INTERVAL " + " ".join(items) def to_tree(self, *args, level=0, **kwargs): return self.get_string( *args, **kwargs) diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 5186024..ae13959 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1352,6 +1352,15 @@ def column_list(self, p): def case(self, p): return Case(rules=p.case_conditions, default=getattr(p, 'expr', None)) + @_('CASE expr case_conditions ELSE expr END', + 'CASE expr case_conditions END') + def case(self, p): + if hasattr(p, 'expr'): + arg, default = p.expr, None + else: + arg, default = p.expr0, p.expr1 + return Case(rules=p.case_conditions, default=default, arg=arg) + @_('case_condition', 'case_conditions case_condition') def case_conditions(self, p): @@ -1364,13 +1373,18 @@ def case_condition(self, p): return [p.expr0, p.expr1] # Window function - @_('function OVER LPAREN window RPAREN') + @_('expr OVER LPAREN window RPAREN', + 'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN') def window_function(self, p): + modifier = None + if hasattr(p, 'BETWEEN'): + modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}' return WindowFunction( - function=p.function, + function=p.expr, order_by=p.window.get('order_by'), partition=p.window.get('partition'), + modifier=modifier, ) @_('window PARTITION_BY expr_list') diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index 05910ee..ff25b4f 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -229,6 +229,19 @@ def find_objects(node, is_table, **kwargs): mdb_entities.append(node) query_traversal(query, find_objects) + + # cte names are not mdb objects + if query.cte: + cte_names = [ + cte.name.parts[-1] + for cte in query.cte + ] + mdb_entities = [ + item + for item in mdb_entities + if '.'.join(item.parts) not in cte_names + ] + return { 'mdb_entities': mdb_entities, 'integrations': integrations, @@ -672,6 +685,16 @@ def plan_delete(self, query: Delete): )) def plan_cte(self, query): + query_info = self.get_query_info(query) + + if ( + len(query_info['integrations']) == 1 + and len(query_info['mdb_entities']) == 0 + and len(query_info['user_functions']) == 0 + ): + # single integration, will be planned later + return + for cte in query.cte: step = self.plan_select(cte.query) name = cte.name.parts[-1] diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index 745a8a9..929f262 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -293,10 +293,15 @@ def prepare_case(self, t: ast.Case): conditions.append( (self.to_expression(condition), self.to_expression(result)) ) + default = None if t.default is not None: - conditions.append(self.to_expression(t.default)) + default = self.to_expression(t.default) - return sa.case(*conditions) + value = None + if t.arg is not None: + value = self.to_expression(t.arg) + + return sa.case(*conditions, else_=default, value=value) def to_function(self, t): op = getattr(sa.func, t.op) diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 07f8abe..b7a6543 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -1026,6 +1026,40 @@ def test_case(self): assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) + def test_case_simple_form(self): + sql = f'''SELECT + CASE R.DELETE_RULE + WHEN 'CASCADE' THEN 0 + WHEN 'SET NULL' THEN 2 + ELSE 3 + END AS DELETE_RULE + FROM COLLATIONS''' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Case( + arg=Identifier('R.DELETE_RULE'), + rules=[ + [ + Constant('CASCADE'), + Constant(0) + ], + [ + Constant('SET NULL'), + Constant(2) + ] + ], + default=Constant(3), + alias=Identifier('DELETE_RULE') + ) + ], + from_table=Identifier('COLLATIONS') + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + def test_select_left(self): sql = f'select left(a, 1) from tab1' ast = parse_sql(sql) @@ -1152,3 +1186,23 @@ def test_table_double_quote(self): ast = parse_sql(sql) assert str(ast) == str(expected_ast) + + def test_window_function_mindsdb(self): + + # modifier + query = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + partition=[Identifier('col1')], + order_by=[OrderBy(field=Identifier('col2'))], + modifier='rows BETWEEN unbounded preceding AND current row' + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_planner/test_integration_select.py b/tests/test_planner/test_integration_select.py index 5180465..faa8831 100644 --- a/tests/test_planner/test_integration_select.py +++ b/tests/test_planner/test_integration_select.py @@ -554,7 +554,7 @@ def test_select_from_table_subselect_api_integration(self): plan = plan_query( query, integrations=[{'name': 'int1', 'class_type': 'api', 'type': 'data'}], - predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}] + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], ) assert plan.steps == expected_plan.steps @@ -583,6 +583,47 @@ def test_select_from_table_subselect_sql_integration(self): assert plan.steps == expected_plan.steps + def test_select_from_single_integration(self): + sql_parsed = ''' + with tab2 as ( + select * from int1.tabl2 + ) + select x from tab2 + join int1.tab1 on 0=0 + where x1 in (select id from int1.tab1) + limit 1 + ''' + + sql_integration = ''' + with tab2 as ( + select * from tabl2 + ) + select x from tab2 + join tab1 on 0=0 + where x1 in (select id as id from tab1) + limit 1 + ''' + query = parse_sql(sql_parsed, dialect='mindsdb') + + expected_plan = QueryPlan( + predictor_namespace='mindsdb', + steps=[ + FetchDataframeStep( + integration='int1', + query=parse_sql(sql_integration), + ), + ], + ) + + plan = plan_query( + query, + integrations=[{'name': 'int1', 'class_type': 'sql', 'type': 'data'}], + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], + default_namespace='mindsdb', + ) + + assert plan.steps == expected_plan.steps + def test_delete_from_table_subselect_api_integration(self): query = parse_sql(''' delete from int1.tab1