Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parser fixes #1 #415

Merged
merged 13 commits into from
Nov 11, 2024
14 changes: 12 additions & 2 deletions mindsdb_sql/parser/ast/select/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@


class Case(ASTNode):
def __init__(self, rules, default=None, *args, **kwargs):
def __init__(self, rules, default=None, arg=None, *args, **kwargs):
super().__init__(*args, **kwargs)

# structure:
# [
# [ condition, result ]
# ]
self.arg = arg
self.rules = rules
self.default = default

Expand All @@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs):
if self.default is not None:
default_str = f'{ind1}default => {self.default.to_string()}\n'

arg_str = ''
if self.arg is not None:
arg_str = f'{ind1}arg => {self.arg.to_string()}\n'

return f'{ind}Case(\n' \
f'{arg_str}'\
f'{rules_str}\n' \
f'{default_str}' \
f'{ind})'
Expand All @@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs):
default_str = ''
if self.default is not None:
default_str = f' ELSE {self.default.to_string()}'
return f"CASE {rules_str}{default_str} END"

arg_str = ''
if self.arg is not None:
arg_str = f'{self.arg.to_string()} '
return f"CASE {arg_str}{rules_str}{default_str} END"
13 changes: 10 additions & 3 deletions mindsdb_sql/parser/ast/select/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs):


class WindowFunction(ASTNode):
def __init__(self, function, partition=None, order_by=None, alias=None):
def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None):
super().__init__()
self.function = function
self.partition = partition
self.order_by = order_by
self.alias = alias
self.modifier = modifier

def to_tree(self, *args, level=0, **kwargs):
fnc_str = self.function.to_tree(level=level+2)
Expand Down Expand Up @@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs):
alias_str = self.alias.to_string()
else:
alias_str = ''
return f'{fnc_str} over({partition_str} {order_str}) {alias_str}'
modifier_str = ' ' + self.modifier if self.modifier else ''
return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}'


class Object(ASTNode):
Expand Down Expand Up @@ -177,7 +179,12 @@ def __init__(self, info):
super().__init__(op='interval', args=[info, ])

def get_string(self, *args, **kwargs):
return f'INTERVAL {self.args[0]}'

arg = self.args[0]
items = arg.split(' ', maxsplit=1)
# quote first element
items[0] = f"'{items[0]}'"
return "INTERVAL " + " ".join(items)

def to_tree(self, *args, level=0, **kwargs):
return self.get_string( *args, **kwargs)
Expand Down
18 changes: 16 additions & 2 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,15 @@ def column_list(self, p):
def case(self, p):
return Case(rules=p.case_conditions, default=getattr(p, 'expr', None))

@_('CASE expr case_conditions ELSE expr END',
'CASE expr case_conditions END')
def case(self, p):
if hasattr(p, 'expr'):
arg, default = p.expr, None
else:
arg, default = p.expr0, p.expr1
return Case(rules=p.case_conditions, default=default, arg=arg)

@_('case_condition',
'case_conditions case_condition')
def case_conditions(self, p):
Expand All @@ -1364,13 +1373,18 @@ def case_condition(self, p):
return [p.expr0, p.expr1]

# Window function
@_('function OVER LPAREN window RPAREN')
@_('expr OVER LPAREN window RPAREN',
'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN')
def window_function(self, p):

modifier = None
if hasattr(p, 'BETWEEN'):
modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}'
return WindowFunction(
function=p.function,
function=p.expr,
order_by=p.window.get('order_by'),
partition=p.window.get('partition'),
modifier=modifier,
)

@_('window PARTITION_BY expr_list')
Expand Down
23 changes: 23 additions & 0 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,19 @@ def find_objects(node, is_table, **kwargs):
mdb_entities.append(node)

query_traversal(query, find_objects)

# cte names are not mdb objects
if query.cte:
cte_names = [
cte.name.parts[-1]
for cte in query.cte
]
mdb_entities = [
item
for item in mdb_entities
if '.'.join(item.parts) not in cte_names
]

return {
'mdb_entities': mdb_entities,
'integrations': integrations,
Expand Down Expand Up @@ -672,6 +685,16 @@ def plan_delete(self, query: Delete):
))

def plan_cte(self, query):
query_info = self.get_query_info(query)

if (
len(query_info['integrations']) == 1
and len(query_info['mdb_entities']) == 0
and len(query_info['user_functions']) == 0
):
# single integration, will be planned later
return

for cte in query.cte:
step = self.plan_select(cte.query)
name = cte.name.parts[-1]
Expand Down
9 changes: 7 additions & 2 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,15 @@ def prepare_case(self, t: ast.Case):
conditions.append(
(self.to_expression(condition), self.to_expression(result))
)
default = None
if t.default is not None:
conditions.append(self.to_expression(t.default))
default = self.to_expression(t.default)

return sa.case(*conditions)
value = None
if t.arg is not None:
value = self.to_expression(t.arg)

return sa.case(*conditions, else_=default, value=value)

def to_function(self, t):
op = getattr(sa.func, t.op)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_parser/test_base_sql/test_select_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,40 @@ def test_case(self):
assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)

def test_case_simple_form(self):
sql = f'''SELECT
CASE R.DELETE_RULE
WHEN 'CASCADE' THEN 0
WHEN 'SET NULL' THEN 2
ELSE 3
END AS DELETE_RULE
FROM COLLATIONS'''
ast = parse_sql(sql)

expected_ast = Select(
targets=[
Case(
arg=Identifier('R.DELETE_RULE'),
rules=[
[
Constant('CASCADE'),
Constant(0)
],
[
Constant('SET NULL'),
Constant(2)
]
],
default=Constant(3),
alias=Identifier('DELETE_RULE')
)
],
from_table=Identifier('COLLATIONS')
)

assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)

def test_select_left(self):
sql = f'select left(a, 1) from tab1'
ast = parse_sql(sql)
Expand Down Expand Up @@ -1152,3 +1186,23 @@ def test_table_double_quote(self):

ast = parse_sql(sql)
assert str(ast) == str(expected_ast)

def test_window_function_mindsdb(self):

# modifier
query = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 "
expected_ast = Select(
targets=[
WindowFunction(
function=Function(op='sum', args=[Identifier('col0')]),
partition=[Identifier('col1')],
order_by=[OrderBy(field=Identifier('col2'))],
modifier='rows BETWEEN unbounded preceding AND current row'
)
],
from_table=Identifier('table1')
)
ast = parse_sql(query)
assert str(ast) == str(expected_ast)
assert ast.to_tree() == expected_ast.to_tree()

43 changes: 42 additions & 1 deletion tests/test_planner/test_integration_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def test_select_from_table_subselect_api_integration(self):
plan = plan_query(
query,
integrations=[{'name': 'int1', 'class_type': 'api', 'type': 'data'}],
predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}]
predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}],
)

assert plan.steps == expected_plan.steps
Expand Down Expand Up @@ -583,6 +583,47 @@ def test_select_from_table_subselect_sql_integration(self):

assert plan.steps == expected_plan.steps

def test_select_from_single_integration(self):
sql_parsed = '''
with tab2 as (
select * from int1.tabl2
)
select x from tab2
join int1.tab1 on 0=0
where x1 in (select id from int1.tab1)
limit 1
'''

sql_integration = '''
with tab2 as (
select * from tabl2
)
select x from tab2
join tab1 on 0=0
where x1 in (select id as id from tab1)
limit 1
'''
query = parse_sql(sql_parsed, dialect='mindsdb')

expected_plan = QueryPlan(
predictor_namespace='mindsdb',
steps=[
FetchDataframeStep(
integration='int1',
query=parse_sql(sql_integration),
),
],
)

plan = plan_query(
query,
integrations=[{'name': 'int1', 'class_type': 'sql', 'type': 'data'}],
predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}],
default_namespace='mindsdb',
)

assert plan.steps == expected_plan.steps

def test_delete_from_table_subselect_api_integration(self):
query = parse_sql('''
delete from int1.tab1
Expand Down