Skip to content

Commit

Permalink
Merge pull request #411 from mindsdb/combining-queries
Browse files Browse the repository at this point in the history
Support combining queries: intersect, except
  • Loading branch information
ea-rus authored Nov 11, 2024
2 parents 57ba416 + 64622e9 commit 410c2f9
Show file tree
Hide file tree
Showing 13 changed files with 332 additions and 120 deletions.
2 changes: 1 addition & 1 deletion mindsdb_sql/parser/ast/select/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .select import Select
from .common_table_expression import CommonTableExpression
from .union import Union
from .union import Union, Except, Intersect
from .constant import Constant, NullConstant, Last
from .star import Star
from .identifier import Identifier
Expand Down
14 changes: 12 additions & 2 deletions mindsdb_sql/parser/ast/select/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@


class Case(ASTNode):
def __init__(self, rules, default=None, *args, **kwargs):
def __init__(self, rules, default=None, arg=None, *args, **kwargs):
super().__init__(*args, **kwargs)

# structure:
# [
# [ condition, result ]
# ]
self.arg = arg
self.rules = rules
self.default = default

Expand All @@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs):
if self.default is not None:
default_str = f'{ind1}default => {self.default.to_string()}\n'

arg_str = ''
if self.arg is not None:
arg_str = f'{ind1}arg => {self.arg.to_string()}\n'

return f'{ind}Case(\n' \
f'{arg_str}'\
f'{rules_str}\n' \
f'{default_str}' \
f'{ind})'
Expand All @@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs):
default_str = ''
if self.default is not None:
default_str = f' ELSE {self.default.to_string()}'
return f"CASE {rules_str}{default_str} END"

arg_str = ''
if self.arg is not None:
arg_str = f'{self.arg.to_string()} '
return f"CASE {arg_str}{rules_str}{default_str} END"
13 changes: 10 additions & 3 deletions mindsdb_sql/parser/ast/select/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs):


class WindowFunction(ASTNode):
def __init__(self, function, partition=None, order_by=None, alias=None):
def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None):
super().__init__()
self.function = function
self.partition = partition
self.order_by = order_by
self.alias = alias
self.modifier = modifier

def to_tree(self, *args, level=0, **kwargs):
fnc_str = self.function.to_tree(level=level+2)
Expand Down Expand Up @@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs):
alias_str = self.alias.to_string()
else:
alias_str = ''
return f'{fnc_str} over({partition_str} {order_str}) {alias_str}'
modifier_str = ' ' + self.modifier if self.modifier else ''
return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}'


class Object(ASTNode):
Expand Down Expand Up @@ -177,7 +179,12 @@ def __init__(self, info):
super().__init__(op='interval', args=[info, ])

def get_string(self, *args, **kwargs):
return f'INTERVAL {self.args[0]}'

arg = self.args[0]
items = arg.split(' ', maxsplit=1)
# quote first element
items[0] = f"'{items[0]}'"
return "INTERVAL " + " ".join(items)

def to_tree(self, *args, level=0, **kwargs):
return self.get_string( *args, **kwargs)
Expand Down
22 changes: 19 additions & 3 deletions mindsdb_sql/parser/ast/select/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from mindsdb_sql.parser.utils import indent


class Union(ASTNode):
class CombiningQuery(ASTNode):
operation = None

def __init__(self,
left,
Expand All @@ -24,7 +25,8 @@ def to_tree(self, *args, level=0, **kwargs):
left_str = f'\n{ind1}left=\n{self.left.to_tree(level=level + 2)},'
right_str = f'\n{ind1}right=\n{self.right.to_tree(level=level + 2)},'

out_str = f'{ind}Union(unique={repr(self.unique)},' \
cls_name = self.__class__.__name__
out_str = f'{ind}{cls_name}(unique={repr(self.unique)},' \
f'{left_str}' \
f'{right_str}' \
f'\n{ind})'
Expand All @@ -33,7 +35,21 @@ def to_tree(self, *args, level=0, **kwargs):
def get_string(self, *args, **kwargs):
left_str = str(self.left)
right_str = str(self.right)
keyword = 'UNION' if self.unique else 'UNION ALL'
keyword = self.operation
if not self.unique:
keyword += ' ALL'
out_str = f"""{left_str}\n{keyword}\n{right_str}"""

return out_str


class Union(CombiningQuery):
operation = 'UNION'


class Intersect(CombiningQuery):
operation = 'INTERSECT'


class Except(CombiningQuery):
operation = 'EXCEPT'
4 changes: 3 additions & 1 deletion mindsdb_sql/parser/dialects/mindsdb/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class MindsDBLexer(Lexer):

JOIN, INNER, OUTER, CROSS, LEFT, RIGHT, ON,

UNION, ALL,
UNION, ALL, INTERSECT, EXCEPT,

# CASE
CASE, ELSE, END, THEN, WHEN,
Expand Down Expand Up @@ -238,6 +238,8 @@ class MindsDBLexer(Lexer):
# UNION

UNION = r'\bUNION\b'
INTERSECT = r'\bINTERSECT\b'
EXCEPT = r'\bEXCEPT\b'
ALL = r'\bALL\b'

# CASE
Expand Down
65 changes: 50 additions & 15 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class MindsDBParser(Parser):
'drop_dataset',
'select',
'insert',
'union',
'update',
'delete',
'evaluate',
Expand Down Expand Up @@ -614,10 +615,13 @@ def update(self, p):

# INSERT
@_('INSERT INTO identifier LPAREN column_list RPAREN select',
'INSERT INTO identifier select')
'INSERT INTO identifier LPAREN column_list RPAREN union',
'INSERT INTO identifier select',
'INSERT INTO identifier union')
def insert(self, p):
columns = getattr(p, 'column_list', None)
return Insert(table=p.identifier, columns=columns, from_select=p.select)
query = p.select if hasattr(p, 'select') else p.union
return Insert(table=p.identifier, columns=columns, from_select=query)

@_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set',
'INSERT INTO identifier VALUES expr_list_set')
Expand Down Expand Up @@ -998,19 +1002,35 @@ def database_engine(self, p):
engine = p.string
return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty}

# UNION / UNION ALL
@_('select UNION select')
def select(self, p):
return Union(left=p[0], right=p[2], unique=True)

@_('select UNION ALL select')
def select(self, p):
return Union(left=p[0], right=p[3], unique=False)
# Combining
@_('select UNION select',
'union UNION select',
'select UNION ALL select',
'union UNION ALL select')
def union(self, p):
unique = not hasattr(p, 'ALL')
return Union(left=p[0], right=p[2] if unique else p[3], unique=unique)

@_('select INTERSECT select',
'union INTERSECT select',
'select INTERSECT ALL select',
'union INTERSECT ALL select')
def union(self, p):
unique = not hasattr(p, 'ALL')
return Intersect(left=p[0], right=p[2] if unique else p[3], unique=unique)
@_('select EXCEPT select',
'union EXCEPT select',
'select EXCEPT ALL select',
'union EXCEPT ALL select')
def union(self, p):
unique = not hasattr(p, 'ALL')
return Except(left=p[0], right=p[2] if unique else p[3], unique=unique)

# tableau
@_('LPAREN select RPAREN')
@_('LPAREN union RPAREN')
def select(self, p):
return p.select
return p[1]

# WITH
@_('ctes select')
Expand All @@ -1030,13 +1050,14 @@ def ctes(self, p):
]
return ctes

@_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN')
@_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN',
'WITH identifier cte_columns_or_nothing AS LPAREN union RPAREN')
def ctes(self, p):
return [
CommonTableExpression(
name=p.identifier,
columns=p.cte_columns_or_nothing,
query=p.select)
query=p[5])
]

@_('empty')
Expand Down Expand Up @@ -1331,6 +1352,15 @@ def column_list(self, p):
def case(self, p):
return Case(rules=p.case_conditions, default=getattr(p, 'expr', None))

@_('CASE expr case_conditions ELSE expr END',
'CASE expr case_conditions END')
def case(self, p):
if hasattr(p, 'expr'):
arg, default = p.expr, None
else:
arg, default = p.expr0, p.expr1
return Case(rules=p.case_conditions, default=default, arg=arg)

@_('case_condition',
'case_conditions case_condition')
def case_conditions(self, p):
Expand All @@ -1343,13 +1373,18 @@ def case_condition(self, p):
return [p.expr0, p.expr1]

# Window function
@_('function OVER LPAREN window RPAREN')
@_('expr OVER LPAREN window RPAREN',
'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN')
def window_function(self, p):

modifier = None
if hasattr(p, 'BETWEEN'):
modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}'
return WindowFunction(
function=p.function,
function=p.expr,
order_by=p.window.get('order_by'),
partition=p.window.get('partition'),
modifier=modifier,
)

@_('window PARTITION_BY expr_list')
Expand Down
Loading

0 comments on commit 410c2f9

Please sign in to comment.