Skip to content

Commit

Permalink
bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
ea-rus committed Oct 29, 2024
1 parent bc8518c commit 008dd27
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 22 deletions.
11 changes: 4 additions & 7 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class MindsDBParser(Parser):
'drop_predictor',
'drop_datasource',
'drop_dataset',
'union',
'select',
'insert',
'update',
Expand Down Expand Up @@ -1000,14 +999,12 @@ def database_engine(self, p):
return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty}

# UNION / UNION ALL
@_('select UNION select',
'union UNION select')
def union(self, p):
@_('select UNION select')
def select(self, p):
return Union(left=p[0], right=p[2], unique=True)

@_('select UNION ALL select',
'union UNION ALL select',)
def union(self, p):
@_('select UNION ALL select')
def select(self, p):
return Union(left=p[0], right=p[3], unique=False)

# tableau
Expand Down
15 changes: 8 additions & 7 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ def plan_delete(self, query: Delete):
))

def plan_select(self, query, integration=None):
if isinstance(query, Union):
return self.plan_union(query, integration=integration)

from_table = query.from_table

if isinstance(from_table, Identifier):
Expand Down Expand Up @@ -713,13 +716,13 @@ def plan_sub_select(self, query, prev_step, add_absent_cols=False):
return sup_select
return prev_step

def plan_union(self, query):
def plan_union(self, query, integration=None):
if isinstance(query.left, Union):
step1 = self.plan_union(query.left)
step1 = self.plan_union(query.left, integration=integration)
else:
# it is select
step1 = self.plan_select(query.left)
step2 = self.plan_select(query.right)
step1 = self.plan_select(query.left, integration=integration)
step2 = self.plan_select(query.right, integration=integration)

return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique))

Expand All @@ -730,10 +733,8 @@ def from_query(self, query=None):
if query is None:
query = self.query

if isinstance(query, Select):
if isinstance(query, (Select, Union)):
self.plan_select(query)
elif isinstance(query, Union):
self.plan_union(query)
elif isinstance(query, CreateTable):
self.plan_create_table(query)
elif isinstance(query, Insert):
Expand Down
19 changes: 11 additions & 8 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ def to_table(self, node):
return table

def prepare_select(self, node):
if isinstance(node, ast.Union):
return self.prepare_union(node)

cols = []
for t in node.targets:
Expand Down Expand Up @@ -454,17 +456,10 @@ def prepare_select(self, node):
full=is_full
)
elif isinstance(from_table, ast.Union):
tables = self.extract_union_list(from_table)

alias = None
if from_table.alias:
alias = self.get_alias(from_table.alias)

table1 = tables[1]
tables_x = tables[1:]

table = table1.union(*tables_x).subquery(alias)

table = self.prepare_union(from_table).subquery(alias)
query = query.select_from(table)

elif isinstance(from_table, ast.Select):
Expand Down Expand Up @@ -529,6 +524,14 @@ def prepare_select(self, node):

return query

def prepare_union(self, from_table):
tables = self.extract_union_list(from_table)

table1 = tables[0]
tables_x = tables[1:]

return table1.union(*tables_x)

def extract_union_list(self, node):
if not (isinstance(node.left, (ast.Select, ast.Union)) and isinstance(node.right, ast.Select)):
raise NotImplementedError(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_parser/test_base_sql/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,25 @@ def test_insert_from_select_no_columns(self, dialect):

assert str(ast).lower() == sql.lower()
assert ast.to_tree() == expected_ast.to_tree()

class TestInsertMDB:

def test_insert_from_union(self):
from textwrap import dedent
sql = dedent("""
INSERT INTO tbl_name(a, c) SELECT * from table1
UNION
SELECT * from table2""")[1:]

ast = parse_sql(sql)
expected_ast = Insert(
table=Identifier('tbl_name'),
columns=[Identifier('a'), Identifier('c')],
from_select=Union(
left=Select(targets=[Star()], from_table=Identifier('table1')),
right=Select(targets=[Star()], from_table=Identifier('table2'))
)
)

assert str(ast).lower() == sql.lower()
assert ast.to_tree() == expected_ast.to_tree()

0 comments on commit 008dd27

Please sign in to comment.