Skip to content

Commit

Permalink
Merge pull request #415 from mindsdb/window-fix
Browse files Browse the repository at this point in the history
Parser fixes #1
  • Loading branch information
ea-rus authored Nov 11, 2024
2 parents fbb913d + fce61dc commit fbc4315
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 10 deletions.
14 changes: 12 additions & 2 deletions mindsdb_sql/parser/ast/select/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@


class Case(ASTNode):
def __init__(self, rules, default=None, arg=None, *args, **kwargs):
    """Create a CASE expression node.

    :param rules: list of ``[condition, result]`` pairs, one per WHEN branch
    :param default: node rendered in the ELSE branch, or None for no ELSE
    :param arg: optional operand rendered right after ``CASE`` (simple form:
        ``CASE <arg> WHEN ... THEN ...``), or None for the searched form

    Any further positional/keyword arguments are forwarded to the parent
    initializer.
    """
    super().__init__(*args, **kwargs)

    # structure:
    # [
    #   [ condition, result ]
    # ]
    self.arg = arg
    self.rules = rules
    self.default = default

Expand All @@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs):
if self.default is not None:
default_str = f'{ind1}default => {self.default.to_string()}\n'

arg_str = ''
if self.arg is not None:
arg_str = f'{ind1}arg => {self.arg.to_string()}\n'

return f'{ind}Case(\n' \
f'{arg_str}'\
f'{rules_str}\n' \
f'{default_str}' \
f'{ind})'
Expand All @@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs):
default_str = ''
if self.default is not None:
default_str = f' ELSE {self.default.to_string()}'
return f"CASE {rules_str}{default_str} END"

arg_str = ''
if self.arg is not None:
arg_str = f'{self.arg.to_string()} '
return f"CASE {arg_str}{rules_str}{default_str} END"
13 changes: 10 additions & 3 deletions mindsdb_sql/parser/ast/select/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs):


class WindowFunction(ASTNode):
def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None):
    """Create a window-function node (``<function> over(...)``).

    :param function: node of the expression the window is applied to
    :param partition: partitioning expressions, or None
    :param order_by: ordering expressions, or None
    :param alias: optional alias node
    :param modifier: optional text appended inside ``over(...)`` after the
        ordering part when the node is rendered; None or empty adds nothing
    """
    super().__init__()
    self.function = function
    self.partition = partition
    self.order_by = order_by
    self.alias = alias
    self.modifier = modifier

def to_tree(self, *args, level=0, **kwargs):
fnc_str = self.function.to_tree(level=level+2)
Expand Down Expand Up @@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs):
alias_str = self.alias.to_string()
else:
alias_str = ''
return f'{fnc_str} over({partition_str} {order_str}) {alias_str}'
modifier_str = ' ' + self.modifier if self.modifier else ''
return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}'


class Object(ASTNode):
Expand Down Expand Up @@ -177,7 +179,12 @@ def __init__(self, info):
super().__init__(op='interval', args=[info, ])

def get_string(self, *args, **kwargs):
    """Render the interval as SQL with its first token single-quoted.

    The stored text (``self.args[0]``) is split at the first space: the
    leading token is wrapped in single quotes and whatever follows is
    appended unchanged, e.g. ``1 day`` -> ``INTERVAL '1' day``.
    Positional and keyword arguments are accepted but not used.
    """
    arg = self.args[0]
    items = arg.split(' ', maxsplit=1)
    # quote first element
    items[0] = f"'{items[0]}'"
    return "INTERVAL " + " ".join(items)

def to_tree(self, *args, level=0, **kwargs):
    """Return the tree form, which is the same text ``get_string`` produces.

    ``level`` is accepted but not forwarded; all other arguments are passed
    through to ``get_string``.
    """
    rendered = self.get_string(*args, **kwargs)
    return rendered
Expand Down
18 changes: 16 additions & 2 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,15 @@ def column_list(self, p):
def case(self, p):
return Case(rules=p.case_conditions, default=getattr(p, 'expr', None))

@_('CASE expr case_conditions ELSE expr END',
   'CASE expr case_conditions END')
def case(self, p):
    """Build a Case node for the simple form ``CASE <arg> WHEN ... [ELSE ...] END``."""
    # With no ELSE branch the production has a single expression, read as
    # ``p.expr``; otherwise the operand and the ELSE value are read as
    # ``p.expr0`` and ``p.expr1``.
    if not hasattr(p, 'expr'):
        operand, fallback = p.expr0, p.expr1
    else:
        operand, fallback = p.expr, None
    return Case(rules=p.case_conditions, default=fallback, arg=operand)

@_('case_condition',
'case_conditions case_condition')
def case_conditions(self, p):
Expand All @@ -1364,13 +1373,18 @@ def case_condition(self, p):
return [p.expr0, p.expr1]

# Window function
@_('expr OVER LPAREN window RPAREN',
   'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN')
def window_function(self, p):
    """Build a WindowFunction node from ``<expr> OVER ( <window> [frame] )``.

    When the production contains BETWEEN, the five identifiers around it are
    joined into a ``modifier`` string of the form
    ``<id0> BETWEEN <id1> <id2> AND <id3> <id4>``; otherwise ``modifier`` is
    None. Partition and ordering come from the ``window`` dict.
    """
    modifier = None
    if hasattr(p, 'BETWEEN'):
        modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}'
    return WindowFunction(
        function=p.expr,
        order_by=p.window.get('order_by'),
        partition=p.window.get('partition'),
        modifier=modifier,
    )

@_('window PARTITION_BY expr_list')
Expand Down
23 changes: 23 additions & 0 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,19 @@ def find_objects(node, is_table, **kwargs):
mdb_entities.append(node)

query_traversal(query, find_objects)

# cte names are not mdb objects
if query.cte:
cte_names = [
cte.name.parts[-1]
for cte in query.cte
]
mdb_entities = [
item
for item in mdb_entities
if '.'.join(item.parts) not in cte_names
]

return {
'mdb_entities': mdb_entities,
'integrations': integrations,
Expand Down Expand Up @@ -672,6 +685,16 @@ def plan_delete(self, query: Delete):
))

def plan_cte(self, query):
query_info = self.get_query_info(query)

if (
len(query_info['integrations']) == 1
and len(query_info['mdb_entities']) == 0
and len(query_info['user_functions']) == 0
):
# single integration, will be planned later
return

for cte in query.cte:
step = self.plan_select(cte.query)
name = cte.name.parts[-1]
Expand Down
9 changes: 7 additions & 2 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,15 @@ def prepare_case(self, t: ast.Case):
conditions.append(
(self.to_expression(condition), self.to_expression(result))
)
default = None
if t.default is not None:
conditions.append(self.to_expression(t.default))
default = self.to_expression(t.default)

return sa.case(*conditions)
value = None
if t.arg is not None:
value = self.to_expression(t.arg)

return sa.case(*conditions, else_=default, value=value)

def to_function(self, t):
op = getattr(sa.func, t.op)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_parser/test_base_sql/test_select_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,40 @@ def test_case(self):
assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)

def test_case_simple_form(self):
    """A CASE with an operand after the keyword parses into a Case node with ``arg`` set."""
    query_text = f'''SELECT
CASE R.DELETE_RULE
WHEN 'CASCADE' THEN 0
WHEN 'SET NULL' THEN 2
ELSE 3
END AS DELETE_RULE
FROM COLLATIONS'''
    parsed = parse_sql(query_text)

    # one [condition, result] pair per WHEN branch
    when_branches = [
        [Constant('CASCADE'), Constant(0)],
        [Constant('SET NULL'), Constant(2)],
    ]
    case_node = Case(
        arg=Identifier('R.DELETE_RULE'),
        rules=when_branches,
        default=Constant(3),
        alias=Identifier('DELETE_RULE'),
    )
    expected = Select(
        targets=[case_node],
        from_table=Identifier('COLLATIONS'),
    )

    assert parsed.to_tree() == expected.to_tree()
    assert str(parsed) == str(expected)

def test_select_left(self):
sql = f'select left(a, 1) from tab1'
ast = parse_sql(sql)
Expand Down Expand Up @@ -1152,3 +1186,23 @@ def test_table_double_quote(self):

ast = parse_sql(sql)
assert str(ast) == str(expected_ast)

def test_window_function_mindsdb(self):
    """The frame clause after ORDER BY inside OVER(...) is kept as the node's ``modifier``."""
    # modifier
    sql = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 "
    window_node = WindowFunction(
        function=Function(op='sum', args=[Identifier('col0')]),
        partition=[Identifier('col1')],
        order_by=[OrderBy(field=Identifier('col2'))],
        modifier='rows BETWEEN unbounded preceding AND current row',
    )
    expected = Select(
        targets=[window_node],
        from_table=Identifier('table1'),
    )
    parsed = parse_sql(sql)
    assert str(parsed) == str(expected)
    assert parsed.to_tree() == expected.to_tree()

43 changes: 42 additions & 1 deletion tests/test_planner/test_integration_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def test_select_from_table_subselect_api_integration(self):
plan = plan_query(
query,
integrations=[{'name': 'int1', 'class_type': 'api', 'type': 'data'}],
predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}]
predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}],
)

assert plan.steps == expected_plan.steps
Expand Down Expand Up @@ -583,6 +583,47 @@ def test_select_from_table_subselect_sql_integration(self):

assert plan.steps == expected_plan.steps

def test_select_from_single_integration(self):
    """A query whose CTE, join and subselect all use one integration plans as a single fetch step."""
    sql_parsed = '''
with tab2 as (
select * from int1.tabl2
)
select x from tab2
join int1.tab1 on 0=0
where x1 in (select id from int1.tab1)
limit 1
'''

    sql_integration = '''
with tab2 as (
select * from tabl2
)
select x from tab2
join tab1 on 0=0
where x1 in (select id as id from tab1)
limit 1
'''
    parsed_query = parse_sql(sql_parsed, dialect='mindsdb')

    # the whole statement is expected to be sent to the integration in one step
    fetch_step = FetchDataframeStep(
        integration='int1',
        query=parse_sql(sql_integration),
    )
    expected_plan = QueryPlan(
        predictor_namespace='mindsdb',
        steps=[fetch_step],
    )

    known_integrations = [{'name': 'int1', 'class_type': 'sql', 'type': 'data'}]
    known_predictors = [{'name': 'pred', 'integration_name': 'mindsdb'}]
    plan = plan_query(
        parsed_query,
        integrations=known_integrations,
        predictor_metadata=known_predictors,
        default_namespace='mindsdb',
    )

    assert plan.steps == expected_plan.steps

def test_delete_from_table_subselect_api_integration(self):
query = parse_sql('''
delete from int1.tab1
Expand Down

0 comments on commit fbc4315

Please sign in to comment.