Skip to content

Commit

Permalink
Merge pull request #413 from mindsdb/cte-support
Browse files Browse the repository at this point in the history
Cte support
  • Loading branch information
ea-rus authored Nov 8, 2024
2 parents 5525c2c + 82cae5a commit 57ba416
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 12 deletions.
8 changes: 5 additions & 3 deletions mindsdb_sql/planner/plan_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def plan(self, query):
query2 = copy.deepcopy(query)
query2.from_table = None
query2.using = None
query2.cte = None
sup_select = QueryStep(query2, from_table=join_step.result)
self.planner.plan.add_step(sup_select)
return sup_select
Expand Down Expand Up @@ -375,7 +376,9 @@ def process_subselect(self, item):
self.step_stack.append(step2)

def process_table(self, item, query_in):
query2 = Select(from_table=item.table, targets=[Star()])
table = copy.deepcopy(item.table)
table.parts.insert(0, item.integration)
query2 = Select(from_table=table, targets=[Star()])
# parts = tuple(map(str.lower, table_name.parts))
conditions = item.conditions
if 'or' in self.query_context['binary_ops']:
Expand Down Expand Up @@ -414,8 +417,7 @@ def process_table(self, item, query_in):
else:
query2.where = cond

# step = self.planner.get_integration_select_step(query2)
step = FetchDataframeStep(integration=item.integration, query=query2)
step = self.planner.get_integration_select_step(query2)
self.tables_fetch_step[item.index] = step

self.add_plan_step(step)
Expand Down
17 changes: 17 additions & 0 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(self,

self.statement = None

self.cte_results = {}

def is_predictor(self, identifier):
if not isinstance(identifier, Identifier):
return False
Expand Down Expand Up @@ -158,6 +160,12 @@ def get_integration_select_step(self, select):
else:
integration_name, table = self.resolve_database_table(select.from_table)

# is it CTE?
table_name = table.parts[-1]
if integration_name == self.default_namespace and table_name in self.cte_results:
select.from_table = None
return SubSelectStep(select, self.cte_results[table_name], table_name=table_name)

fetch_df_select = copy.deepcopy(select)
self.prepare_integration_select(integration_name, fetch_df_select)

Expand Down Expand Up @@ -663,10 +671,19 @@ def plan_delete(self, query: Delete):
where=query.where
))

def plan_cte(self, query):
for cte in query.cte:
step = self.plan_select(cte.query)
name = cte.name.parts[-1]
self.cte_results[name] = step.result

def plan_select(self, query, integration=None):
if isinstance(query, Union):
return self.plan_union(query, integration=integration)

if query.cte is not None:
self.plan_cte(query)

from_table = query.from_table

if isinstance(from_table, Identifier):
Expand Down
2 changes: 2 additions & 0 deletions mindsdb_sql/planner/query_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def find_predictors(node, is_table, **kwargs):

elif column.name is not None:
# is Identifier
if isinstance(column.name, ast.Star):
continue
col_name = column.name.upper()
if column.table is not None:
table = column.table
Expand Down
8 changes: 0 additions & 8 deletions tests/test_planner/test_join_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,6 @@ def test_join_predictor_plan_limit(self):
# plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}})

def test_join_predictor_plan_complex_query(self):
query = Select(targets=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')],
from_table=Join(left=Identifier('int.tab'),
right=Identifier('mindsdb.pred'),
join_type=JoinType.INNER_JOIN,
implicit=True),
group_by=[Identifier('tab.asset')],
having=BinaryOperation('=', args=[Identifier('tab.asset'), Constant('bitcoin')])
)

sql = """
select t.asset, t.time, m.predicted
Expand Down
33 changes: 32 additions & 1 deletion tests/test_planner/test_join_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,4 +489,35 @@ def test_join_one_integration(self):
)
plan = plan_query(query, integrations=['int'], default_namespace='int')

assert plan.steps == expected_plan.steps
assert plan.steps == expected_plan.steps

def test_cte(self):
query = parse_sql('''
with t1 as (
select * from int1.tbl1
)
select t1.id, t2.* from t1
join int2.tbl2 t2 on t1.id>t2.id
''')

subquery = copy.deepcopy(query)
subquery.from_table = None

plan = plan_query(query, integrations=['int1', 'int2'], default_namespace='mindsdb')

expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1')),
SubSelectStep(dataframe=Result(0), query=Select(targets=[Star()]), table_name='t1'),
FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl2 as t2')),
JoinStep(left=Result(1),
right=Result(2),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='>', args=[Identifier('t1.id'), Identifier('t2.id')]),
join_type=JoinType.JOIN)),
QueryStep(parse_sql('SELECT t1.`id`, t2.*'), from_table=Result(3)),
]
)

assert plan.steps == expected_plan.steps

0 comments on commit 57ba416

Please sign in to comment.