From e6b2356a28273802e7646236d300613284533fa5 Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 30 Oct 2024 14:55:51 +0300 Subject: [PATCH 01/20] classes for intersect, except --- mindsdb_sql/parser/ast/select/union.py | 22 ++++++++++++++++--- mindsdb_sql/parser/dialects/mindsdb/lexer.py | 4 +++- mindsdb_sql/parser/dialects/mindsdb/parser.py | 7 ++---- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/mindsdb_sql/parser/ast/select/union.py b/mindsdb_sql/parser/ast/select/union.py index e78609ea..dce1da2b 100644 --- a/mindsdb_sql/parser/ast/select/union.py +++ b/mindsdb_sql/parser/ast/select/union.py @@ -2,7 +2,8 @@ from mindsdb_sql.parser.utils import indent -class Union(ASTNode): +class CombiningQuery(ASTNode): + operation = None def __init__(self, left, @@ -24,7 +25,8 @@ def to_tree(self, *args, level=0, **kwargs): left_str = f'\n{ind1}left=\n{self.left.to_tree(level=level + 2)},' right_str = f'\n{ind1}right=\n{self.right.to_tree(level=level + 2)},' - out_str = f'{ind}Union(unique={repr(self.unique)},' \ + cls_name = self.__class__.__name__ + out_str = f'{ind}{cls_name}(unique={repr(self.unique)},' \ f'{left_str}' \ f'{right_str}' \ f'\n{ind})' @@ -33,7 +35,21 @@ def to_tree(self, *args, level=0, **kwargs): def get_string(self, *args, **kwargs): left_str = str(self.left) right_str = str(self.right) - keyword = 'UNION' if self.unique else 'UNION ALL' + keyword = self.operation + if not self.unique: + keyword += ' ALL' out_str = f"""{left_str}\n{keyword}\n{right_str}""" return out_str + + +class Union(CombiningQuery): + operation = 'UNION' + + +class Intersect(CombiningQuery): + operation = 'INTERSECT' + + +class Except(CombiningQuery): + operation = 'EXCEPT' diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 01024c97..9232efbc 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -55,7 +55,7 @@ class MindsDBLexer(Lexer): JOIN, INNER, OUTER, CROSS, LEFT, RIGHT, ON, - UNION, ALL, + UNION, ALL, INTERSECT, EXCEPT, # CASE CASE, ELSE, END, THEN, WHEN, @@ -238,6 +238,8 @@ class MindsDBLexer(Lexer): # UNION UNION = r'\bUNION\b' + INTERSECT = r'\bINTERSECT\b' + EXCEPT = r'\bEXCEPT\b' ALL = r'\bALL\b' # CASE diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index c0d1edde..0879628e 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -998,14 +998,11 @@ def database_engine(self, p): engine = p.string return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty} - # UNION / UNION ALL + # Combining @_('select UNION select') - def select(self, p): - return Union(left=p[0], right=p[2], unique=True) - @_('select UNION ALL select') def select(self, p): - return Union(left=p[0], right=p[3], unique=False) + return Union(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL')) # tableau @_('LPAREN select RPAREN') From fe2730801cb8610877ac7cc5bcc67f2b99b4ef9a Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 30 Oct 2024 15:29:04 +0300 Subject: [PATCH 02/20] intersect, except --- mindsdb_sql/parser/ast/select/__init__.py | 2 +- mindsdb_sql/parser/dialects/mindsdb/parser.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mindsdb_sql/parser/ast/select/__init__.py b/mindsdb_sql/parser/ast/select/__init__.py index d064242c..7dd64b61 100644 --- a/mindsdb_sql/parser/ast/select/__init__.py +++ b/mindsdb_sql/parser/ast/select/__init__.py @@ -1,6 +1,6 @@ from .select import Select from .common_table_expression import CommonTableExpression -from .union import Union +from .union import Union, Except, Intersect from .constant import Constant, NullConstant, Last from .star import Star from .identifier import Identifier diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 0879628e..3733f422 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1004,6 +1004,16 @@ def database_engine(self, p): def select(self, p): return Union(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL')) + @_('select INTERSECT select') + @_('select INTERSECT ALL select') + def select(self, p): + return Intersect(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL')) + + @_('select EXCEPT select') + @_('select EXCEPT ALL select') + def select(self, p): + return Except(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL')) + # tableau @_('LPAREN select RPAREN') def select(self, p): From d79cf771b4653413d9288cb62e3e7a4637ef9a94 Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 30 Oct 2024 15:34:33 +0300 Subject: [PATCH 03/20] tests for intersect, except --- tests/test_parser/test_base_sql/test_union.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/tests/test_parser/test_base_sql/test_union.py b/tests/test_parser/test_base_sql/test_union.py index 1545e4b0..cc4fe775 100644 --- a/tests/test_parser/test_base_sql/test_union.py +++ b/tests/test_parser/test_base_sql/test_union.py @@ -11,38 +11,40 @@ def test_single_select_error(self): parse_sql(sql) def test_union_base(self): - sql = """SELECT col1 FROM tab1 - UNION - SELECT col1 FROM tab2""" + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} + SELECT col1 FROM tab2""" - ast = parse_sql(sql) - expected_ast = Union(unique=True, - left=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']), - ), - right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']), - ), - ) - assert ast.to_tree() == expected_ast.to_tree() - assert str(ast) == str(expected_ast) + ast = parse_sql(sql) + expected_ast = cls(unique=True, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) def test_union_all(self): - sql = """SELECT col1 FROM tab1 - UNION ALL - SELECT col1 FROM tab2""" + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} ALL + SELECT col1 FROM tab2""" - ast = parse_sql(sql) - expected_ast = Union(unique=False, - left=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']), - ), - right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']), - ), - ) - assert ast.to_tree() == expected_ast.to_tree() - assert str(ast) == str(expected_ast) + ast = parse_sql(sql) + expected_ast = cls(unique=False, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) def xtest_union_alias(self): sql = """SELECT * FROM ( From d5eb2da16ddd466dbb8e279e37a15eae8f898bef Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 30 Oct 2024 15:34:46 +0300 Subject: [PATCH 04/20] enabled test --- tests/test_parser/test_base_sql/test_union.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_parser/test_base_sql/test_union.py b/tests/test_parser/test_base_sql/test_union.py index cc4fe775..bc49f2a2 100644 --- a/tests/test_parser/test_base_sql/test_union.py +++ b/tests/test_parser/test_base_sql/test_union.py @@ -46,7 +46,7 @@ def test_union_all(self): assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) - def xtest_union_alias(self): + def test_union_alias(self): sql = """SELECT * FROM ( SELECT col1 FROM tab1 UNION @@ -60,16 +60,17 @@ def xtest_union_alias(self): from_table=Union( unique=True, alias=Identifier('alias'), - left=Union( + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']),), + right=Union( unique=True, left=Select( targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']),), + from_table=Identifier(parts=['tab2']),), right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']),), + from_table=Identifier(parts=['tab3']),), ), - right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab3']),), + ) ) assert ast.to_tree() == expected_ast.to_tree() From e01c09494e01d4ebfb7d9f71b365db7baaa6e3de Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 6 Nov 2024 15:18:59 +0300 Subject: [PATCH 05/20] fix standard render 'interval': number have to be quoted --- mindsdb_sql/parser/ast/select/operation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mindsdb_sql/parser/ast/select/operation.py b/mindsdb_sql/parser/ast/select/operation.py index d7cad6a6..1a3ca23c 100644 --- a/mindsdb_sql/parser/ast/select/operation.py +++ b/mindsdb_sql/parser/ast/select/operation.py @@ -177,7 +177,12 @@ def __init__(self, info): super().__init__(op='interval', args=[info, ]) def get_string(self, *args, **kwargs): - return f'INTERVAL {self.args[0]}' + + arg = self.args[0] + items = arg.split(' ', maxsplit=1) + # quote first element + items[0] = f"'{items[0]}'" + return "INTERVAL " + " ".join(items) def to_tree(self, *args, level=0, **kwargs): return self.get_string( *args, **kwargs) From 63a206e3934c4d5a2c0cf698debf5b6a061bc7e4 Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 6 Nov 2024 15:19:25 +0300 Subject: [PATCH 06/20] fix case with default value for sqlachemy render --- mindsdb_sql/render/sqlalchemy_render.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index bcd54f4e..01db44c5 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -293,10 +293,11 @@ def prepare_case(self, t: ast.Case): conditions.append( (self.to_expression(condition), self.to_expression(result)) ) + default = None if t.default is not None: - conditions.append(self.to_expression(t.default)) + default = self.to_expression(t.default) - return sa.case(*conditions) + return sa.case(*conditions, else_=default) def to_function(self, t): op = getattr(sa.func, t.op) From c34fefc08c28d7b0d1ce4afa4d2167a6c0a86eb7 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 8 Nov 2024 16:58:04 +0300 Subject: [PATCH 07/20] planning render combining queries --- mindsdb_sql/parser/dialects/mindsdb/parser.py | 44 ++++++++++++------- mindsdb_sql/planner/query_planner.py | 19 ++++---- mindsdb_sql/planner/steps.py | 3 +- mindsdb_sql/render/sqlalchemy_render.py | 29 +++++------- tests/test_parser/test_base_sql/test_union.py | 11 +++-- 5 files changed, 55 insertions(+), 51 deletions(-) diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index d741649c..abe2fc54 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -70,6 +70,7 @@ class MindsDBParser(Parser): 'drop_dataset', 'select', 'insert', + 'union', 'update', 'delete', 'evaluate', @@ -614,10 +615,13 @@ def update(self, p): # INSERT @_('INSERT INTO identifier LPAREN column_list RPAREN select', - 'INSERT INTO identifier select') + 'INSERT INTO identifier LPAREN column_list RPAREN union', + 'INSERT INTO identifier select', + 'INSERT INTO identifier union') def insert(self, p): columns = getattr(p, 'column_list', None) - return Insert(table=p.identifier, columns=columns, from_select=p.select) + query = p.select if hasattr(p, 'select') else p.union + return Insert(table=p.identifier, columns=columns, from_select=query) @_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set', 'INSERT INTO identifier VALUES expr_list_set') @@ -999,20 +1003,28 @@ def database_engine(self, p): return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty} # Combining - @_('select UNION select') - @_('select UNION ALL select') - def select(self, p): - return Union(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL')) - - @_('select INTERSECT select') - @_('select INTERSECT ALL select') - def select(self, p): - return Intersect(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL')) - - @_('select EXCEPT select') - @_('select EXCEPT ALL select') - def select(self, p): - return Except(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL')) + @_('select UNION select', + 'union UNION select', + 'select UNION ALL select', + 'union UNION ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Union(left=p[0], right=p[2] if unique else p[3], unique=unique) + + @_('select INTERSECT select', + 'union INTERSECT select', + 'select INTERSECT ALL select', + 'union INTERSECT ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Intersect(left=p[0], right=p[2] if unique else p[3], unique=unique) + @_('select EXCEPT select', + 'union EXCEPT select', + 'select EXCEPT ALL select', + 'union EXCEPT ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Except(left=p[0], right=p[2] if unique else p[3], unique=unique) # tableau @_('LPAREN select RPAREN') diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index f2d08d16..05910ee5 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -3,7 +3,7 @@ from mindsdb_sql.exceptions import PlanningException from mindsdb_sql.parser import ast from mindsdb_sql.parser.ast import (Select, Identifier, Join, Star, BinaryOperation, Constant, Union, CreateTable, - Function, Insert, + Function, Insert, Except, Intersect, Update, NativeQuery, Parameter, Delete) from mindsdb_sql.planner import utils from mindsdb_sql.planner.query_plan import QueryPlan @@ -678,7 +678,7 @@ def plan_cte(self, query): self.cte_results[name] = step.result def plan_select(self, query, integration=None): - if isinstance(query, Union): + if isinstance(query, (Union, Except, Intersect)): return self.plan_union(query, integration=integration) if query.cte is not None: @@ -734,14 +734,15 @@ def plan_sub_select(self, query, prev_step, add_absent_cols=False): return prev_step def plan_union(self, query, integration=None): - if isinstance(query.left, Union): - step1 = self.plan_union(query.left, integration=integration) - else: - # it is select - step1 = self.plan_select(query.left, integration=integration) + step1 = self.plan_select(query.left, integration=integration) step2 = self.plan_select(query.right, integration=integration) + operation = 'union' + if isinstance(query, Except): + operation = 'except' + elif isinstance(query, Intersect): + operation = 'intersect' - return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique)) + return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique, operation=operation)) # method for compatibility def from_query(self, query=None): @@ -750,7 +751,7 @@ def from_query(self, query=None): if query is None: query = self.query - if isinstance(query, (Select, Union)): + if isinstance(query, (Select, Union, Except, Intersect)): self.plan_select(query) elif isinstance(query, CreateTable): self.plan_create_table(query) diff --git a/mindsdb_sql/planner/steps.py b/mindsdb_sql/planner/steps.py index 40e86e53..395f3079 100644 --- a/mindsdb_sql/planner/steps.py +++ b/mindsdb_sql/planner/steps.py @@ -75,11 +75,12 @@ def __init__(self, left, right, query, *args, **kwargs): class UnionStep(PlanStep): """Union of two dataframes, producing a new dataframe""" - def __init__(self, left, right, unique, *args, **kwargs): + def __init__(self, left, right, unique, operation='union', *args, **kwargs): super().__init__(*args, **kwargs) self.left = left self.right = right self.unique = unique + self.operation = operation # TODO remove diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index bcd54f4e..745a8a92 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -396,7 +396,7 @@ def to_table(self, node): return table def prepare_select(self, node): - if isinstance(node, ast.Union): + if isinstance(node, (ast.Union, ast.Except, ast.Intersect)): return self.prepare_union(node) cols = [] @@ -525,26 +525,17 @@ def prepare_select(self, node): return query def prepare_union(self, from_table): - tables = self.extract_union_list(from_table) + step1 = self.prepare_select(from_table.left) + step2 = self.prepare_select(from_table.right) - table1 = tables[0] - tables_x = tables[1:] - - return table1.union(*tables_x) - - def extract_union_list(self, node): - if not (isinstance(node.left, (ast.Select, ast.Union)) and isinstance(node.right, ast.Select)): - raise NotImplementedError( - f'Unknown UNION {node.left.__class__.__name__}, {node.right.__class__.__name__}') - - tables = [] - if isinstance(node.left, ast.Union): - tables.extend(self.extract_union_list(node.left)) + if isinstance(from_table, ast.Except): + func = sa.except_ if from_table.unique else sa.except_all + elif isinstance(from_table, ast.Intersect): + func = sa.intersect if from_table.unique else sa.intersect_all else: - tables.append(self.prepare_select(node.left)) - tables.append(self.prepare_select(node.right)) - return tables + func = sa.union if from_table.unique else sa.union_all + return func(step1, step2) def prepare_create_table(self, ast_query): columns = [] @@ -693,7 +684,7 @@ def prepare_delete(self, ast_query: ast.Delete): def get_query(self, ast_query, with_params=False): params = None - if isinstance(ast_query, ast.Select): + if isinstance(ast_query, (ast.Select, ast.Union, ast.Except, ast.Intersect)): stmt = self.prepare_select(ast_query) elif isinstance(ast_query, ast.Insert): stmt, params = self.prepare_insert(ast_query, with_params=with_params) diff --git a/tests/test_parser/test_base_sql/test_union.py b/tests/test_parser/test_base_sql/test_union.py index bc49f2a2..d8d5dddf 100644 --- a/tests/test_parser/test_base_sql/test_union.py +++ b/tests/test_parser/test_base_sql/test_union.py @@ -60,17 +60,16 @@ def test_union_alias(self): from_table=Union( unique=True, alias=Identifier('alias'), - left=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']),), - right=Union( + left=Union( unique=True, left=Select( targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']),), + from_table=Identifier(parts=['tab1']),), right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab3']),), + from_table=Identifier(parts=['tab2']),), ), - + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab3']),), ) ) assert ast.to_tree() == expected_ast.to_tree() From 6c93ee08e02714ec6f39d9c024179b9f6a795c93 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 8 Nov 2024 17:19:43 +0300 Subject: [PATCH 08/20] fix parsing (.. union ..) union .. --- mindsdb_sql/parser/dialects/mindsdb/parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index abe2fc54..523e30a0 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1028,8 +1028,9 @@ def union(self, p): # tableau @_('LPAREN select RPAREN') + @_('LPAREN union RPAREN') def select(self, p): - return p.select + return p[1] # WITH @_('ctes select') From 70146e8fb2d36849898d8c368176b2e4f8cd3e42 Mon Sep 17 00:00:00 2001 From: andrew Date: Sun, 10 Nov 2024 22:44:32 +0300 Subject: [PATCH 09/20] cte with union --- mindsdb_sql/parser/dialects/mindsdb/parser.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 523e30a0..5186024c 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1050,13 +1050,14 @@ def ctes(self, p): ] return ctes - @_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN') + @_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN', + 'WITH identifier cte_columns_or_nothing AS LPAREN union RPAREN') def ctes(self, p): return [ CommonTableExpression( name=p.identifier, columns=p.cte_columns_or_nothing, - query=p.select) + query=p[5]) ] @_('empty') From fbb913d0d8d3a284a017c685a2fdbd9be9d3cc26 Mon Sep 17 00:00:00 2001 From: andrew Date: Sun, 10 Nov 2024 22:44:58 +0300 Subject: [PATCH 10/20] combining queries traversal --- mindsdb_sql/planner/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb_sql/planner/utils.py b/mindsdb_sql/planner/utils.py index e4d52787..9baa16ec 100644 --- a/mindsdb_sql/planner/utils.py +++ b/mindsdb_sql/planner/utils.py @@ -145,7 +145,7 @@ def query_traversal(node, callback, is_table=False, is_target=False, parent_quer array.append(node_out) node.order_by = array - elif isinstance(node, ast.Union): + elif isinstance(node, (ast.Union, ast.Intersect, ast.Except)): node_out = query_traversal(node.left, callback, parent_query=node) if node_out is not None: node.left = node_out From 0db06c6e502f4ff82ae9bfa54069074c1ffc7a77 Mon Sep 17 00:00:00 2001 From: andrew Date: Sun, 10 Nov 2024 22:49:45 +0300 Subject: [PATCH 11/20] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED in window function --- mindsdb_sql/parser/ast/select/operation.py | 6 ++++-- mindsdb_sql/parser/dialects/mindsdb/parser.py | 9 +++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/mindsdb_sql/parser/ast/select/operation.py b/mindsdb_sql/parser/ast/select/operation.py index d7cad6a6..f844aa79 100644 --- a/mindsdb_sql/parser/ast/select/operation.py +++ b/mindsdb_sql/parser/ast/select/operation.py @@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs): class WindowFunction(ASTNode): - def __init__(self, function, partition=None, order_by=None, alias=None): + def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None): super().__init__() self.function = function self.partition = partition self.order_by = order_by self.alias = alias + self.modifier = modifier def to_tree(self, *args, level=0, **kwargs): fnc_str = self.function.to_tree(level=level+2) @@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs): alias_str = self.alias.to_string() else: alias_str = '' - return f'{fnc_str} over({partition_str} {order_str}) {alias_str}' + modifier_str = self.modifier if self.modifier else '' + return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}' class Object(ASTNode): diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 5186024c..a5ca39ce 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1364,13 +1364,18 @@ def case_condition(self, p): return [p.expr0, p.expr1] # Window function - @_('function OVER LPAREN window RPAREN') + @_('expr OVER LPAREN window RPAREN', + 'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN') def window_function(self, p): + modifier = None + if hasattr(p, 'BETWEEN'): + modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}' return WindowFunction( - function=p.function, + function=p.expr, order_by=p.window.get('order_by'), partition=p.window.get('partition'), + modifier=modifier, ) @_('window PARTITION_BY expr_list') From 8ac9fe2ae1f6f5045ee9d64c41f926f1159819c9 Mon Sep 17 00:00:00 2001 From: andrew Date: Sun, 10 Nov 2024 22:51:00 +0300 Subject: [PATCH 12/20] simple form of case --- mindsdb_sql/parser/ast/select/case.py | 14 ++++++++++++-- mindsdb_sql/parser/dialects/mindsdb/parser.py | 9 +++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mindsdb_sql/parser/ast/select/case.py b/mindsdb_sql/parser/ast/select/case.py index 02fa8275..1ae0df06 100644 --- a/mindsdb_sql/parser/ast/select/case.py +++ b/mindsdb_sql/parser/ast/select/case.py @@ -4,13 +4,14 @@ class Case(ASTNode): - def __init__(self, rules, default=None, *args, **kwargs): + def __init__(self, rules, default=None, arg=None, *args, **kwargs): super().__init__(*args, **kwargs) # structure: # [ # [ condition, result ] # ] + self.arg = arg self.rules = rules self.default = default @@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs): if self.default is not None: default_str = f'{ind1}default => {self.default.to_string()}\n' + arg_str = '' + if self.arg is not None: + arg_str = f'{ind1}arg => {self.arg.to_string()}\n' + return f'{ind}Case(\n' \ + f'{arg_str}'\ f'{rules_str}\n' \ f'{default_str}' \ f'{ind})' @@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs): default_str = '' if self.default is not None: default_str = f' ELSE {self.default.to_string()}' - return f"CASE {rules_str}{default_str} END" + + arg_str = '' + if self.arg is not None: + arg_str = f'{self.arg.to_string()} ' + return f"CASE {arg_str}{rules_str}{default_str} END" diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index a5ca39ce..ae139597 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1352,6 +1352,15 @@ def column_list(self, p): def case(self, p): return Case(rules=p.case_conditions, default=getattr(p, 'expr', None)) + @_('CASE expr case_conditions ELSE expr END', + 'CASE expr case_conditions END') + def case(self, p): + if hasattr(p, 'expr'): + arg, default = p.expr, None + else: + arg, default = p.expr0, p.expr1 + return Case(rules=p.case_conditions, default=default, arg=arg) + @_('case_condition', 'case_conditions case_condition') def case_conditions(self, p): From b65132df7a7553722bf99ddddd1f7eaa2f5191d1 Mon Sep 17 00:00:00 2001 From: andrew Date: Sun, 10 Nov 2024 22:51:32 +0300 Subject: [PATCH 13/20] query with cte to one integration --- mindsdb_sql/planner/query_planner.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index 05910ee5..ff25b4fd 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -229,6 +229,19 @@ def find_objects(node, is_table, **kwargs): mdb_entities.append(node) query_traversal(query, find_objects) + + # cte names are not mdb objects + if query.cte: + cte_names = [ + cte.name.parts[-1] + for cte in query.cte + ] + mdb_entities = [ + item + for item in mdb_entities + if '.'.join(item.parts) not in cte_names + ] + return { 'mdb_entities': mdb_entities, 'integrations': integrations, @@ -672,6 +685,16 @@ def plan_delete(self, query: Delete): )) def plan_cte(self, query): + query_info = self.get_query_info(query) + + if ( + len(query_info['integrations']) == 1 + and len(query_info['mdb_entities']) == 0 + and len(query_info['user_functions']) == 0 + ): + # single integration, will be planned later + return + for cte in query.cte: step = self.plan_select(cte.query) name = cte.name.parts[-1] From 4298b66fd86758de4fb2177c8a7f98cb9b625d85 Mon Sep 17 00:00:00 2001 From: andrew Date: Sun, 10 Nov 2024 23:01:40 +0300 Subject: [PATCH 14/20] fix case render --- mindsdb_sql/render/sqlalchemy_render.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index 7a30aea9..929f262f 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -297,7 +297,11 @@ def prepare_case(self, t: ast.Case): if t.default is not None: default = self.to_expression(t.default) - return sa.case(*conditions, else_=default) + value = None + if t.arg is not None: + value = self.to_expression(t.arg) + + return sa.case(*conditions, else_=default, value=value) def to_function(self, t): op = getattr(sa.func, t.op) From ad6ee9d1cb6e447023649422a17f1b1fd6730f11 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 11 Nov 2024 10:15:53 +0300 Subject: [PATCH 15/20] test_case_simple_form --- .../test_base_sql/test_select_structure.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 07f8abee..0320bd3d 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -1026,6 +1026,40 @@ def test_case(self): assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) + def test_case_simple_form(self): + sql = f'''SELECT + CASE R.DELETE_RULE + WHEN 'CASCADE' THEN 0 + WHEN 'SET NULL' THEN 2 + ELSE 3 + END AS DELETE_RULE + FROM COLLATIONS''' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Case( + arg=Identifier('R.DELETE_RULE'), + rules=[ + [ + Constant('CASCADE'), + Constant(0) + ], + [ + Constant('SET NULL'), + Constant(2) + ] + ], + default=Constant(3), + alias=Identifier('DELETE_RULE') + ) + ], + from_table=Identifier('COLLATIONS') + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + def test_select_left(self): sql = f'select left(a, 1) from tab1' ast = parse_sql(sql) From 2fb6ce4681b0f0f2544933515da55f3529dfd514 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 11 Nov 2024 10:21:27 +0300 Subject: [PATCH 16/20] window function test --- .../test_base_sql/test_select_structure.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 0320bd3d..b7a6543c 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -1186,3 +1186,23 @@ def test_table_double_quote(self): ast = parse_sql(sql) assert str(ast) == str(expected_ast) + + def test_window_function_mindsdb(self): + + # modifier + query = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + partition=[Identifier('col1')], + order_by=[OrderBy(field=Identifier('col2'))], + modifier='rows BETWEEN unbounded preceding AND current row' + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + From 5e594fbd9f899e771c9bd3c7c1dcc1c101887900 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 11 Nov 2024 11:01:38 +0300 Subject: [PATCH 17/20] test single_integration --- tests/test_planner/test_integration_select.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/test_planner/test_integration_select.py b/tests/test_planner/test_integration_select.py index 5180465d..faa88315 100644 --- a/tests/test_planner/test_integration_select.py +++ b/tests/test_planner/test_integration_select.py @@ -554,7 +554,7 @@ def test_select_from_table_subselect_api_integration(self): plan = plan_query( query, integrations=[{'name': 'int1', 'class_type': 'api', 'type': 'data'}], - predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}] + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], ) assert plan.steps == expected_plan.steps @@ -583,6 +583,47 @@ def test_select_from_table_subselect_sql_integration(self): assert plan.steps == expected_plan.steps + def test_select_from_single_integration(self): + sql_parsed = ''' + with tab2 as ( + select * from int1.tabl2 + ) + select x from tab2 + join int1.tab1 on 0=0 + where x1 in (select id from int1.tab1) + limit 1 + ''' + + sql_integration = ''' + with tab2 as ( + select * from tabl2 + ) + select x from tab2 + join tab1 on 0=0 + where x1 in (select id as id from tab1) + limit 1 + ''' + query = parse_sql(sql_parsed, dialect='mindsdb') + + expected_plan = QueryPlan( + predictor_namespace='mindsdb', + steps=[ + FetchDataframeStep( + integration='int1', + query=parse_sql(sql_integration), + ), + ], + ) + + plan = plan_query( + query, + integrations=[{'name': 'int1', 'class_type': 'sql', 'type': 'data'}], + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], + default_namespace='mindsdb', + ) + + assert plan.steps == expected_plan.steps + def test_delete_from_table_subselect_api_integration(self): query = parse_sql(''' delete from int1.tab1 From fce61dc8404bd592d9b6c6e73c3e0c3d29997593 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 11 Nov 2024 11:08:30 +0300 Subject: [PATCH 18/20] fix window function --- mindsdb_sql/parser/ast/select/operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb_sql/parser/ast/select/operation.py b/mindsdb_sql/parser/ast/select/operation.py index 9f8d95c0..e208d435 100644 --- a/mindsdb_sql/parser/ast/select/operation.py +++ b/mindsdb_sql/parser/ast/select/operation.py @@ -144,7 +144,7 @@ def to_string(self, *args, **kwargs): alias_str = self.alias.to_string() else: alias_str = '' - modifier_str = self.modifier if self.modifier else '' + modifier_str = ' ' + self.modifier if self.modifier else '' return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}' From c6983a24ee0e56e7362eecd9c9f91579a97b11e2 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 11 Nov 2024 12:25:36 +0300 Subject: [PATCH 19/20] plan union in single integration query --- mindsdb_sql/planner/query_planner.py | 98 +++++++++++-------- tests/test_planner/test_integration_select.py | 26 +++-- 2 files changed, 74 insertions(+), 50 deletions(-) diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index ff25b4fd..35bea81c 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -231,7 +231,7 @@ def find_objects(node, is_table, **kwargs): query_traversal(query, find_objects) # cte names are not mdb objects - if query.cte: + if isinstance(query, Select) and query.cte: cte_names = [ cte.name.parts[-1] for cte in query.cte @@ -271,21 +271,21 @@ def find_selects(node, **kwargs): return find_selects def plan_select_identifier(self, query): - query_info = self.get_query_info(query) - - if len(query_info['integrations']) == 0 and len(query_info['predictors']) >= 1: - # select from predictor - return self.plan_select_from_predictor(query) - elif ( - len(query_info['integrations']) == 1 - and len(query_info['mdb_entities']) == 0 - and len(query_info['user_functions']) == 0 - ): - - int_name = list(query_info['integrations'])[0] - if self.integrations.get(int_name, {}).get('class_type') != 'api': - # one integration without predictors, send all query to integration - return self.plan_integration_select(query) + # query_info = self.get_query_info(query) + # + # if len(query_info['integrations']) == 0 and len(query_info['predictors']) >= 1: + # # select from predictor + # return self.plan_select_from_predictor(query) + # elif ( + # len(query_info['integrations']) == 1 + # and len(query_info['mdb_entities']) == 0 + # and len(query_info['user_functions']) == 0 + # ): + # + # int_name = list(query_info['integrations'])[0] + # if self.integrations.get(int_name, {}).get('class_type') != 'api': + # # one integration without predictors, send all query to integration + # return self.plan_integration_select(query) # find subselects main_integration, _ = self.resolve_database_table(query.from_table) @@ -380,21 +380,21 @@ def plan_api_db_select(self, query): def plan_nested_select(self, select): - query_info = self.get_query_info(select) - # get all predictors - - if ( - len(query_info['mdb_entities']) == 0 - and len(query_info['integrations']) == 1 - and len(query_info['user_functions']) == 0 - and 'files' not in query_info['integrations'] - and 'views' not in query_info['integrations'] - ): - int_name = list(query_info['integrations'])[0] - if self.integrations.get(int_name, {}).get('class_type') != 'api': - - # if no predictor inside = run as is - return self.plan_integration_nested_select(select, int_name) + # query_info = self.get_query_info(select) + # # get all predictors + # + # if ( + # len(query_info['mdb_entities']) == 0 + # and len(query_info['integrations']) == 1 + # and len(query_info['user_functions']) == 0 + # and 'files' not in query_info['integrations'] + # and 'views' not in query_info['integrations'] + # ): + # int_name = list(query_info['integrations'])[0] + # if self.integrations.get(int_name, {}).get('class_type') != 'api': + # + # # if no predictor inside = run as is + # return self.plan_integration_nested_select(select, int_name) return self.plan_mdb_nested_select(select) @@ -685,22 +685,38 @@ def plan_delete(self, query: Delete): )) def plan_cte(self, query): - query_info = self.get_query_info(query) - - if ( - len(query_info['integrations']) == 1 - and len(query_info['mdb_entities']) == 0 - and len(query_info['user_functions']) == 0 - ): - # single integration, will be planned later - return for cte in query.cte: step = self.plan_select(cte.query) name = cte.name.parts[-1] self.cte_results[name] = step.result + def check_single_integration(self, query): + query_info = self.get_query_info(query) + + # can we send all query to integration? + + # one integration and not mindsdb objects in query + if ( + len(query_info['mdb_entities']) == 0 + and len(query_info['integrations']) == 1 + and 'files' not in query_info['integrations'] + and 'views' not in query_info['integrations'] + and len(query_info['user_functions']) == 0 + ): + + int_name = list(query_info['integrations'])[0] + # if is sql database + if self.integrations.get(int_name, {}).get('class_type') != 'api': + + # send to this integration + self.prepare_integration_select(int_name, query) + + last_step = self.plan.add_step(FetchDataframeStep(integration=int_name, query=query)) + return last_step + def plan_select(self, query, integration=None): + if isinstance(query, (Union, Except, Intersect)): return self.plan_union(query, integration=integration) @@ -775,6 +791,8 @@ def from_query(self, query=None): query = self.query if isinstance(query, (Select, Union, Except, Intersect)): + if self.check_single_integration(query): + return self.plan self.plan_select(query) elif isinstance(query, CreateTable): self.plan_create_table(query) diff --git a/tests/test_planner/test_integration_select.py b/tests/test_planner/test_integration_select.py index faa88315..dab5e65a 100644 --- a/tests/test_planner/test_integration_select.py +++ b/tests/test_planner/test_integration_select.py @@ -290,7 +290,7 @@ def test_integration_select_subquery_in_from(self): steps=[ FetchDataframeStep(integration='int', query=Select( - targets=[Identifier('column1')], + targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Select( targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Identifier('tab'), @@ -378,7 +378,7 @@ def test_integration_select_default_namespace_subquery_in_from(self): steps=[ FetchDataframeStep(integration='int', query=Select( - targets=[Identifier('column1')], + targets=[Identifier('column1', alias=Identifier('column1')),], from_table=Select( targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Identifier('tab'), @@ -588,20 +588,26 @@ def test_select_from_single_integration(self): with tab2 as ( select * from int1.tabl2 ) - select x from tab2 - join int1.tab1 on 0=0 - where x1 in (select id from int1.tab1) - limit 1 + select a from ( + select x from tab2 + union + select y from int1.tab1 + where x1 in (select id from int1.tab1) + limit 1 + ) ''' sql_integration = ''' with tab2 as ( select * from tabl2 ) - select x from tab2 - join tab1 on 0=0 - where x1 in (select id as id from tab1) - limit 1 + select a as a from ( + select x as x from tab2 + union + select y as y from tab1 + where x1 in (select id as id from tab1) + limit 1 + ) ''' query = parse_sql(sql_parsed, dialect='mindsdb') From 64622e9db124272f79bf7412f86f4ef0684bdc27 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 11 Nov 2024 14:18:16 +0300 Subject: [PATCH 20/20] fix render select from union --- mindsdb_sql/render/sqlalchemy_render.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index 929f262f..c4c31aaf 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -387,7 +387,7 @@ def to_table(self, node): if node.alias: table = aliased(table, name=self.get_alias(node.alias)) - elif isinstance(node, ast.Select): + elif isinstance(node, (ast.Select, ast.Union, ast.Intersect, ast.Except)): sub_stmt = self.prepare_select(node) alias = None if node.alias: