From 008dd27955e69436de9ab7339ae9d05986f9d026 Mon Sep 17 00:00:00 2001 From: andrew Date: Tue, 29 Oct 2024 19:04:46 +0300 Subject: [PATCH] bump version --- mindsdb_sql/parser/dialects/mindsdb/parser.py | 11 ++++------ mindsdb_sql/planner/query_planner.py | 15 +++++++------ mindsdb_sql/render/sqlalchemy_render.py | 19 +++++++++------- .../test_parser/test_base_sql/test_insert.py | 22 +++++++++++++++++++ 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 05233429..c84ac9b6 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -68,7 +68,6 @@ class MindsDBParser(Parser): 'drop_predictor', 'drop_datasource', 'drop_dataset', - 'union', 'select', 'insert', 'update', @@ -1000,14 +999,12 @@ def database_engine(self, p): return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty} # UNION / UNION ALL - @_('select UNION select', - 'union UNION select') - def union(self, p): + @_('select UNION select') + def select(self, p): return Union(left=p[0], right=p[2], unique=True) - @_('select UNION ALL select', - 'union UNION ALL select',) - def union(self, p): + @_('select UNION ALL select') + def select(self, p): return Union(left=p[0], right=p[3], unique=False) # tableau diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index e6d58be2..90b697b9 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -664,6 +664,9 @@ def plan_delete(self, query: Delete): )) def plan_select(self, query, integration=None): + if isinstance(query, Union): + return self.plan_union(query, integration=integration) + from_table = query.from_table if isinstance(from_table, Identifier): @@ -713,13 +716,13 @@ def plan_sub_select(self, query, prev_step, add_absent_cols=False): return sup_select return prev_step - def plan_union(self, query): + def plan_union(self, query, integration=None): if isinstance(query.left, Union): - step1 = self.plan_union(query.left) + step1 = self.plan_union(query.left, integration=integration) else: # it is select - step1 = self.plan_select(query.left) - step2 = self.plan_select(query.right) + step1 = self.plan_select(query.left, integration=integration) + step2 = self.plan_select(query.right, integration=integration) return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique)) @@ -730,10 +733,8 @@ def from_query(self, query=None): if query is None: query = self.query - if isinstance(query, Select): + if isinstance(query, (Select, Union)): self.plan_select(query) - elif isinstance(query, Union): - self.plan_union(query) elif isinstance(query, CreateTable): self.plan_create_table(query) elif isinstance(query, Insert): diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index caa9e232..4d7e513d 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -396,6 +396,8 @@ def to_table(self, node): return table def prepare_select(self, node): + if isinstance(node, ast.Union): + return self.prepare_union(node) cols = [] for t in node.targets: @@ -454,17 +456,10 @@ def prepare_select(self, node): full=is_full ) elif isinstance(from_table, ast.Union): - tables = self.extract_union_list(from_table) - alias = None if from_table.alias: alias = self.get_alias(from_table.alias) - - table1 = tables[1] - tables_x = tables[1:] - - table = table1.union(*tables_x).subquery(alias) - + table = self.prepare_union(from_table).subquery(alias) query = query.select_from(table) elif isinstance(from_table, ast.Select): @@ -529,6 +524,14 @@ def prepare_select(self, node): return query + def prepare_union(self, from_table): + tables = self.extract_union_list(from_table) + + table1 = tables[0] + tables_x = tables[1:] + + return table1.union(*tables_x) + def extract_union_list(self, node): if not (isinstance(node.left, (ast.Select, ast.Union)) and isinstance(node.right, ast.Select)): raise NotImplementedError( diff --git a/tests/test_parser/test_base_sql/test_insert.py b/tests/test_parser/test_base_sql/test_insert.py index 1318b071..ed69a4b3 100644 --- a/tests/test_parser/test_base_sql/test_insert.py +++ b/tests/test_parser/test_base_sql/test_insert.py @@ -74,3 +74,25 @@ def test_insert_from_select_no_columns(self, dialect): assert str(ast).lower() == sql.lower() assert ast.to_tree() == expected_ast.to_tree() + +class TestInsertMDB: + + def test_insert_from_union(self): + from textwrap import dedent + sql = dedent(""" + INSERT INTO tbl_name(a, c) SELECT * from table1 + UNION + SELECT * from table2""")[1:] + + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier('tbl_name'), + columns=[Identifier('a'), Identifier('c')], + from_select=Union( + left=Select(targets=[Star()], from_table=Identifier('table1')), + right=Select(targets=[Star()], from_table=Identifier('table2')) + ) + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() \ No newline at end of file