diff --git a/mindsdb_sql/parser/ast/select/__init__.py b/mindsdb_sql/parser/ast/select/__init__.py index d064242..7dd64b6 100644 --- a/mindsdb_sql/parser/ast/select/__init__.py +++ b/mindsdb_sql/parser/ast/select/__init__.py @@ -1,6 +1,6 @@ from .select import Select from .common_table_expression import CommonTableExpression -from .union import Union +from .union import Union, Except, Intersect from .constant import Constant, NullConstant, Last from .star import Star from .identifier import Identifier diff --git a/mindsdb_sql/parser/ast/select/case.py b/mindsdb_sql/parser/ast/select/case.py index 02fa827..1ae0df0 100644 --- a/mindsdb_sql/parser/ast/select/case.py +++ b/mindsdb_sql/parser/ast/select/case.py @@ -4,13 +4,14 @@ class Case(ASTNode): - def __init__(self, rules, default=None, *args, **kwargs): + def __init__(self, rules, default=None, arg=None, *args, **kwargs): super().__init__(*args, **kwargs) # structure: # [ # [ condition, result ] # ] + self.arg = arg self.rules = rules self.default = default @@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs): if self.default is not None: default_str = f'{ind1}default => {self.default.to_string()}\n' + arg_str = '' + if self.arg is not None: + arg_str = f'{ind1}arg => {self.arg.to_string()}\n' + return f'{ind}Case(\n' \ + f'{arg_str}'\ f'{rules_str}\n' \ f'{default_str}' \ f'{ind})' @@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs): default_str = '' if self.default is not None: default_str = f' ELSE {self.default.to_string()}' - return f"CASE {rules_str}{default_str} END" + + arg_str = '' + if self.arg is not None: + arg_str = f'{self.arg.to_string()} ' + return f"CASE {arg_str}{rules_str}{default_str} END" diff --git a/mindsdb_sql/parser/ast/select/operation.py b/mindsdb_sql/parser/ast/select/operation.py index d7cad6a..e208d43 100644 --- a/mindsdb_sql/parser/ast/select/operation.py +++ b/mindsdb_sql/parser/ast/select/operation.py @@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs): class WindowFunction(ASTNode): - def __init__(self, function, partition=None, order_by=None, alias=None): + def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None): super().__init__() self.function = function self.partition = partition self.order_by = order_by self.alias = alias + self.modifier = modifier def to_tree(self, *args, level=0, **kwargs): fnc_str = self.function.to_tree(level=level+2) @@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs): alias_str = self.alias.to_string() else: alias_str = '' - return f'{fnc_str} over({partition_str} {order_str}) {alias_str}' + modifier_str = ' ' + self.modifier if self.modifier else '' + return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}' class Object(ASTNode): @@ -177,7 +179,12 @@ def __init__(self, info): super().__init__(op='interval', args=[info, ]) def get_string(self, *args, **kwargs): - return f'INTERVAL {self.args[0]}' + + arg = self.args[0] + items = arg.split(' ', maxsplit=1) + # quote first element + items[0] = f"'{items[0]}'" + return "INTERVAL " + " ".join(items) def to_tree(self, *args, level=0, **kwargs): return self.get_string( *args, **kwargs) diff --git a/mindsdb_sql/parser/ast/select/union.py b/mindsdb_sql/parser/ast/select/union.py index e78609e..dce1da2 100644 --- a/mindsdb_sql/parser/ast/select/union.py +++ b/mindsdb_sql/parser/ast/select/union.py @@ -2,7 +2,8 @@ from mindsdb_sql.parser.utils import indent -class Union(ASTNode): +class CombiningQuery(ASTNode): + operation = None def __init__(self, left, @@ -24,7 +25,8 @@ def to_tree(self, *args, level=0, **kwargs): left_str = f'\n{ind1}left=\n{self.left.to_tree(level=level + 2)},' right_str = f'\n{ind1}right=\n{self.right.to_tree(level=level + 2)},' - out_str = f'{ind}Union(unique={repr(self.unique)},' \ + cls_name = self.__class__.__name__ + out_str = f'{ind}{cls_name}(unique={repr(self.unique)},' \ f'{left_str}' \ f'{right_str}' \ f'\n{ind})' @@ -33,7 +35,21 @@ def to_tree(self, *args, level=0, **kwargs): def get_string(self, *args, **kwargs): left_str = str(self.left) right_str = str(self.right) - keyword = 'UNION' if self.unique else 'UNION ALL' + keyword = self.operation + if not self.unique: + keyword += ' ALL' out_str = f"""{left_str}\n{keyword}\n{right_str}""" return out_str + + +class Union(CombiningQuery): + operation = 'UNION' + + +class Intersect(CombiningQuery): + operation = 'INTERSECT' + + +class Except(CombiningQuery): + operation = 'EXCEPT' diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 01024c9..9232efb 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -55,7 +55,7 @@ class MindsDBLexer(Lexer): JOIN, INNER, OUTER, CROSS, LEFT, RIGHT, ON, - UNION, ALL, + UNION, ALL, INTERSECT, EXCEPT, # CASE CASE, ELSE, END, THEN, WHEN, @@ -238,6 +238,8 @@ class MindsDBLexer(Lexer): # UNION UNION = r'\bUNION\b' + INTERSECT = r'\bINTERSECT\b' + EXCEPT = r'\bEXCEPT\b' ALL = r'\bALL\b' # CASE diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index c740cf4..ae13959 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -70,6 +70,7 @@ class MindsDBParser(Parser): 'drop_dataset', 'select', 'insert', + 'union', 'update', 'delete', 'evaluate', @@ -614,10 +615,13 @@ def update(self, p): # INSERT @_('INSERT INTO identifier LPAREN column_list RPAREN select', - 'INSERT INTO identifier select') + 'INSERT INTO identifier LPAREN column_list RPAREN union', + 'INSERT INTO identifier select', + 'INSERT INTO identifier union') def insert(self, p): columns = getattr(p, 'column_list', None) - return Insert(table=p.identifier, columns=columns, from_select=p.select) + query = p.select if hasattr(p, 'select') else p.union + return Insert(table=p.identifier, columns=columns, from_select=query) @_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set', 'INSERT INTO identifier VALUES expr_list_set') @@ -998,19 +1002,35 @@ def database_engine(self, p): engine = p.string return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty} - # UNION / UNION ALL - @_('select UNION select') - def select(self, p): - return Union(left=p[0], right=p[2], unique=True) - - @_('select UNION ALL select') - def select(self, p): - return Union(left=p[0], right=p[3], unique=False) + # Combining + @_('select UNION select', + 'union UNION select', + 'select UNION ALL select', + 'union UNION ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Union(left=p[0], right=p[2] if unique else p[3], unique=unique) + + @_('select INTERSECT select', + 'union INTERSECT select', + 'select INTERSECT ALL select', + 'union INTERSECT ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Intersect(left=p[0], right=p[2] if unique else p[3], unique=unique) + @_('select EXCEPT select', + 'union EXCEPT select', + 'select EXCEPT ALL select', + 'union EXCEPT ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Except(left=p[0], right=p[2] if unique else p[3], unique=unique) # tableau @_('LPAREN select RPAREN') + @_('LPAREN union RPAREN') def select(self, p): - return p.select + return p[1] # WITH @_('ctes select') @@ -1030,13 +1050,14 @@ def ctes(self, p): ] return ctes - @_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN') + @_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN', + 'WITH identifier cte_columns_or_nothing AS LPAREN union RPAREN') def ctes(self, p): return [ CommonTableExpression( name=p.identifier, columns=p.cte_columns_or_nothing, - query=p.select) + query=p[5]) ] @_('empty') @@ -1331,6 +1352,15 @@ def column_list(self, p): def case(self, p): return Case(rules=p.case_conditions, default=getattr(p, 'expr', None)) + @_('CASE expr case_conditions ELSE expr END', + 'CASE expr case_conditions END') + def case(self, p): + if hasattr(p, 'expr'): + arg, default = p.expr, None + else: + arg, default = p.expr0, p.expr1 + return Case(rules=p.case_conditions, default=default, arg=arg) + @_('case_condition', 'case_conditions case_condition') def case_conditions(self, p): @@ -1343,13 +1373,18 @@ def case_condition(self, p): return [p.expr0, p.expr1] # Window function - @_('function OVER LPAREN window RPAREN') + @_('expr OVER LPAREN window RPAREN', + 'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN') def window_function(self, p): + modifier = None + if hasattr(p, 'BETWEEN'): + modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}' return WindowFunction( - function=p.function, + function=p.expr, order_by=p.window.get('order_by'), partition=p.window.get('partition'), + modifier=modifier, ) @_('window PARTITION_BY expr_list') diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index f2d08d1..35bea81 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -3,7 +3,7 @@ from mindsdb_sql.exceptions import PlanningException from mindsdb_sql.parser import ast from mindsdb_sql.parser.ast import (Select, Identifier, Join, Star, BinaryOperation, Constant, Union, CreateTable, - Function, Insert, + Function, Insert, Except, Intersect, Update, NativeQuery, Parameter, Delete) from mindsdb_sql.planner import utils from mindsdb_sql.planner.query_plan import QueryPlan @@ -229,6 +229,19 @@ def find_objects(node, is_table, **kwargs): mdb_entities.append(node) query_traversal(query, find_objects) + + # cte names are not mdb objects + if isinstance(query, Select) and query.cte: + cte_names = [ + cte.name.parts[-1] + for cte in query.cte + ] + mdb_entities = [ + item + for item in mdb_entities + if '.'.join(item.parts) not in cte_names + ] + return { 'mdb_entities': mdb_entities, 'integrations': integrations, @@ -258,21 +271,21 @@ def find_selects(node, **kwargs): return find_selects def plan_select_identifier(self, query): - query_info = self.get_query_info(query) - - if len(query_info['integrations']) == 0 and len(query_info['predictors']) >= 1: - # select from predictor - return self.plan_select_from_predictor(query) - elif ( - len(query_info['integrations']) == 1 - and len(query_info['mdb_entities']) == 0 - and len(query_info['user_functions']) == 0 - ): - - int_name = list(query_info['integrations'])[0] - if self.integrations.get(int_name, {}).get('class_type') != 'api': - # one integration without predictors, send all query to integration - return self.plan_integration_select(query) + # query_info = self.get_query_info(query) + # + # if len(query_info['integrations']) == 0 and len(query_info['predictors']) >= 1: + # # select from predictor + # return self.plan_select_from_predictor(query) + # elif ( + # len(query_info['integrations']) == 1 + # and len(query_info['mdb_entities']) == 0 + # and len(query_info['user_functions']) == 0 + # ): + # + # int_name = list(query_info['integrations'])[0] + # if self.integrations.get(int_name, {}).get('class_type') != 'api': + # # one integration without predictors, send all query to integration + # return self.plan_integration_select(query) # find subselects main_integration, _ = self.resolve_database_table(query.from_table) @@ -367,21 +380,21 @@ def plan_api_db_select(self, query): def plan_nested_select(self, select): - query_info = self.get_query_info(select) - # get all predictors - - if ( - len(query_info['mdb_entities']) == 0 - and len(query_info['integrations']) == 1 - and len(query_info['user_functions']) == 0 - and 'files' not in query_info['integrations'] - and 'views' not in query_info['integrations'] - ): - int_name = list(query_info['integrations'])[0] - if self.integrations.get(int_name, {}).get('class_type') != 'api': - - # if no predictor inside = run as is - return self.plan_integration_nested_select(select, int_name) + # query_info = self.get_query_info(select) + # # get all predictors + # + # if ( + # len(query_info['mdb_entities']) == 0 + # and len(query_info['integrations']) == 1 + # and len(query_info['user_functions']) == 0 + # and 'files' not in query_info['integrations'] + # and 'views' not in query_info['integrations'] + # ): + # int_name = list(query_info['integrations'])[0] + # if self.integrations.get(int_name, {}).get('class_type') != 'api': + # + # # if no predictor inside = run as is + # return self.plan_integration_nested_select(select, int_name) return self.plan_mdb_nested_select(select) @@ -672,13 +685,39 @@ def plan_delete(self, query: Delete): )) def plan_cte(self, query): + for cte in query.cte: step = self.plan_select(cte.query) name = cte.name.parts[-1] self.cte_results[name] = step.result + def check_single_integration(self, query): + query_info = self.get_query_info(query) + + # can we send all query to integration? + + # one integration and not mindsdb objects in query + if ( + len(query_info['mdb_entities']) == 0 + and len(query_info['integrations']) == 1 + and 'files' not in query_info['integrations'] + and 'views' not in query_info['integrations'] + and len(query_info['user_functions']) == 0 + ): + + int_name = list(query_info['integrations'])[0] + # if is sql database + if self.integrations.get(int_name, {}).get('class_type') != 'api': + + # send to this integration + self.prepare_integration_select(int_name, query) + + last_step = self.plan.add_step(FetchDataframeStep(integration=int_name, query=query)) + return last_step + def plan_select(self, query, integration=None): - if isinstance(query, Union): + + if isinstance(query, (Union, Except, Intersect)): return self.plan_union(query, integration=integration) if query.cte is not None: @@ -734,14 +773,15 @@ def plan_sub_select(self, query, prev_step, add_absent_cols=False): return prev_step def plan_union(self, query, integration=None): - if isinstance(query.left, Union): - step1 = self.plan_union(query.left, integration=integration) - else: - # it is select - step1 = self.plan_select(query.left, integration=integration) + step1 = self.plan_select(query.left, integration=integration) step2 = self.plan_select(query.right, integration=integration) + operation = 'union' + if isinstance(query, Except): + operation = 'except' + elif isinstance(query, Intersect): + operation = 'intersect' - return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique)) + return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique, operation=operation)) # method for compatibility def from_query(self, query=None): @@ -750,7 +790,9 @@ def from_query(self, query=None): if query is None: query = self.query - if isinstance(query, (Select, Union)): + if isinstance(query, (Select, Union, Except, Intersect)): + if self.check_single_integration(query): + return self.plan self.plan_select(query) elif isinstance(query, CreateTable): self.plan_create_table(query) diff --git a/mindsdb_sql/planner/steps.py b/mindsdb_sql/planner/steps.py index 40e86e5..395f307 100644 --- a/mindsdb_sql/planner/steps.py +++ b/mindsdb_sql/planner/steps.py @@ -75,11 +75,12 @@ def __init__(self, left, right, query, *args, **kwargs): class UnionStep(PlanStep): """Union of two dataframes, producing a new dataframe""" - def __init__(self, left, right, unique, *args, **kwargs): + def __init__(self, left, right, unique, operation='union', *args, **kwargs): super().__init__(*args, **kwargs) self.left = left self.right = right self.unique = unique + self.operation = operation # TODO remove diff --git a/mindsdb_sql/planner/utils.py b/mindsdb_sql/planner/utils.py index e4d5278..9baa16e 100644 --- a/mindsdb_sql/planner/utils.py +++ b/mindsdb_sql/planner/utils.py @@ -145,7 +145,7 @@ def query_traversal(node, callback, is_table=False, is_target=False, parent_quer array.append(node_out) node.order_by = array - elif isinstance(node, ast.Union): + elif isinstance(node, (ast.Union, ast.Intersect, ast.Except)): node_out = query_traversal(node.left, callback, parent_query=node) if node_out is not None: node.left = node_out diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index bcd54f4..c4c31aa 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -293,10 +293,15 @@ def prepare_case(self, t: ast.Case): conditions.append( (self.to_expression(condition), self.to_expression(result)) ) + default = None if t.default is not None: - conditions.append(self.to_expression(t.default)) + default = self.to_expression(t.default) - return sa.case(*conditions) + value = None + if t.arg is not None: + value = self.to_expression(t.arg) + + return sa.case(*conditions, else_=default, value=value) def to_function(self, t): op = getattr(sa.func, t.op) @@ -382,7 +387,7 @@ def to_table(self, node): if node.alias: table = aliased(table, name=self.get_alias(node.alias)) - elif isinstance(node, ast.Select): + elif isinstance(node, (ast.Select, ast.Union, ast.Intersect, ast.Except)): sub_stmt = self.prepare_select(node) alias = None if node.alias: @@ -396,7 +401,7 @@ def to_table(self, node): return table def prepare_select(self, node): - if isinstance(node, ast.Union): + if isinstance(node, (ast.Union, ast.Except, ast.Intersect)): return self.prepare_union(node) cols = [] @@ -525,26 +530,17 @@ def prepare_select(self, node): return query def prepare_union(self, from_table): - tables = self.extract_union_list(from_table) - - table1 = tables[0] - tables_x = tables[1:] - - return table1.union(*tables_x) - - def extract_union_list(self, node): - if not (isinstance(node.left, (ast.Select, ast.Union)) and isinstance(node.right, ast.Select)): - raise NotImplementedError( - f'Unknown UNION {node.left.__class__.__name__}, {node.right.__class__.__name__}') + step1 = self.prepare_select(from_table.left) + step2 = self.prepare_select(from_table.right) - tables = [] - if isinstance(node.left, ast.Union): - tables.extend(self.extract_union_list(node.left)) + if isinstance(from_table, ast.Except): + func = sa.except_ if from_table.unique else sa.except_all + elif isinstance(from_table, ast.Intersect): + func = sa.intersect if from_table.unique else sa.intersect_all else: - tables.append(self.prepare_select(node.left)) - tables.append(self.prepare_select(node.right)) - return tables + func = sa.union if from_table.unique else sa.union_all + return func(step1, step2) def prepare_create_table(self, ast_query): columns = [] @@ -693,7 +689,7 @@ def prepare_delete(self, ast_query: ast.Delete): def get_query(self, ast_query, with_params=False): params = None - if isinstance(ast_query, ast.Select): + if isinstance(ast_query, (ast.Select, ast.Union, ast.Except, ast.Intersect)): stmt = self.prepare_select(ast_query) elif isinstance(ast_query, ast.Insert): stmt, params = self.prepare_insert(ast_query, with_params=with_params) diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 07f8abe..b7a6543 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -1026,6 +1026,40 @@ def test_case(self): assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) + def test_case_simple_form(self): + sql = f'''SELECT + CASE R.DELETE_RULE + WHEN 'CASCADE' THEN 0 + WHEN 'SET NULL' THEN 2 + ELSE 3 + END AS DELETE_RULE + FROM COLLATIONS''' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Case( + arg=Identifier('R.DELETE_RULE'), + rules=[ + [ + Constant('CASCADE'), + Constant(0) + ], + [ + Constant('SET NULL'), + Constant(2) + ] + ], + default=Constant(3), + alias=Identifier('DELETE_RULE') + ) + ], + from_table=Identifier('COLLATIONS') + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + def test_select_left(self): sql = f'select left(a, 1) from tab1' ast = parse_sql(sql) @@ -1152,3 +1186,23 @@ def test_table_double_quote(self): ast = parse_sql(sql) assert str(ast) == str(expected_ast) + + def test_window_function_mindsdb(self): + + # modifier + query = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + partition=[Identifier('col1')], + order_by=[OrderBy(field=Identifier('col2'))], + modifier='rows BETWEEN unbounded preceding AND current row' + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_parser/test_base_sql/test_union.py b/tests/test_parser/test_base_sql/test_union.py index 1545e4b..d8d5ddd 100644 --- a/tests/test_parser/test_base_sql/test_union.py +++ b/tests/test_parser/test_base_sql/test_union.py @@ -11,40 +11,42 @@ def test_single_select_error(self): parse_sql(sql) def test_union_base(self): - sql = """SELECT col1 FROM tab1 - UNION - SELECT col1 FROM tab2""" + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} + SELECT col1 FROM tab2""" - ast = parse_sql(sql) - expected_ast = Union(unique=True, - left=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']), - ), - right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']), - ), - ) - assert ast.to_tree() == expected_ast.to_tree() - assert str(ast) == str(expected_ast) + ast = parse_sql(sql) + expected_ast = cls(unique=True, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) def test_union_all(self): - sql = """SELECT col1 FROM tab1 - UNION ALL - SELECT col1 FROM tab2""" + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} ALL + SELECT col1 FROM tab2""" - ast = parse_sql(sql) - expected_ast = Union(unique=False, - left=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']), - ), - right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']), - ), - ) - assert ast.to_tree() == expected_ast.to_tree() - assert str(ast) == str(expected_ast) + ast = parse_sql(sql) + expected_ast = cls(unique=False, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) - def xtest_union_alias(self): + def test_union_alias(self): sql = """SELECT * FROM ( SELECT col1 FROM tab1 UNION diff --git a/tests/test_planner/test_integration_select.py b/tests/test_planner/test_integration_select.py index 5180465..dab5e65 100644 --- a/tests/test_planner/test_integration_select.py +++ b/tests/test_planner/test_integration_select.py @@ -290,7 +290,7 @@ def test_integration_select_subquery_in_from(self): steps=[ FetchDataframeStep(integration='int', query=Select( - targets=[Identifier('column1')], + targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Select( targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Identifier('tab'), @@ -378,7 +378,7 @@ def test_integration_select_default_namespace_subquery_in_from(self): steps=[ FetchDataframeStep(integration='int', query=Select( - targets=[Identifier('column1')], + targets=[Identifier('column1', alias=Identifier('column1')),], from_table=Select( targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Identifier('tab'), @@ -554,7 +554,7 @@ def test_select_from_table_subselect_api_integration(self): plan = plan_query( query, integrations=[{'name': 'int1', 'class_type': 'api', 'type': 'data'}], - predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}] + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], ) assert plan.steps == expected_plan.steps @@ -583,6 +583,53 @@ def test_select_from_table_subselect_sql_integration(self): assert plan.steps == expected_plan.steps + def test_select_from_single_integration(self): + sql_parsed = ''' + with tab2 as ( + select * from int1.tabl2 + ) + select a from ( + select x from tab2 + union + select y from int1.tab1 + where x1 in (select id from int1.tab1) + limit 1 + ) + ''' + + sql_integration = ''' + with tab2 as ( + select * from tabl2 + ) + select a as a from ( + select x as x from tab2 + union + select y as y from tab1 + where x1 in (select id as id from tab1) + limit 1 + ) + ''' + query = parse_sql(sql_parsed, dialect='mindsdb') + + expected_plan = QueryPlan( + predictor_namespace='mindsdb', + steps=[ + FetchDataframeStep( + integration='int1', + query=parse_sql(sql_integration), + ), + ], + ) + + plan = plan_query( + query, + integrations=[{'name': 'int1', 'class_type': 'sql', 'type': 'data'}], + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], + default_namespace='mindsdb', + ) + + assert plan.steps == expected_plan.steps + def test_delete_from_table_subselect_api_integration(self): query = parse_sql(''' delete from int1.tab1