diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index ea86784..1bb3afa 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -112,6 +112,7 @@ def plan(self, query): query2 = copy.deepcopy(query) query2.from_table = None query2.using = None + query2.cte = None sup_select = QueryStep(query2, from_table=join_step.result) self.planner.plan.add_step(sup_select) return sup_select @@ -375,7 +376,9 @@ def process_subselect(self, item): self.step_stack.append(step2) def process_table(self, item, query_in): - query2 = Select(from_table=item.table, targets=[Star()]) + table = copy.deepcopy(item.table) + table.parts.insert(0, item.integration) + query2 = Select(from_table=table, targets=[Star()]) # parts = tuple(map(str.lower, table_name.parts)) conditions = item.conditions if 'or' in self.query_context['binary_ops']: @@ -414,8 +417,7 @@ def process_table(self, item, query_in): else: query2.where = cond - # step = self.planner.get_integration_select_step(query2) - step = FetchDataframeStep(integration=item.integration, query=query2) + step = self.planner.get_integration_select_step(query2) self.tables_fetch_step[item.index] = step self.add_plan_step(step) diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index 90b697b..f2d08d1 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -86,6 +86,8 @@ def __init__(self, self.statement = None + self.cte_results = {} + def is_predictor(self, identifier): if not isinstance(identifier, Identifier): return False @@ -158,6 +160,12 @@ def get_integration_select_step(self, select): else: integration_name, table = self.resolve_database_table(select.from_table) + # is it CTE? + table_name = table.parts[-1] + if integration_name == self.default_namespace and table_name in self.cte_results: + select.from_table = None + return SubSelectStep(select, self.cte_results[table_name], table_name=table_name) + fetch_df_select = copy.deepcopy(select) self.prepare_integration_select(integration_name, fetch_df_select) @@ -663,10 +671,19 @@ def plan_delete(self, query: Delete): where=query.where )) + def plan_cte(self, query): + for cte in query.cte: + step = self.plan_select(cte.query) + name = cte.name.parts[-1] + self.cte_results[name] = step.result + def plan_select(self, query, integration=None): if isinstance(query, Union): return self.plan_union(query, integration=integration) + if query.cte is not None: + self.plan_cte(query) + from_table = query.from_table if isinstance(from_table, Identifier): diff --git a/mindsdb_sql/planner/query_prepare.py b/mindsdb_sql/planner/query_prepare.py index 26a10c8..9614dfc 100644 --- a/mindsdb_sql/planner/query_prepare.py +++ b/mindsdb_sql/planner/query_prepare.py @@ -348,6 +348,8 @@ def find_predictors(node, is_table, **kwargs): elif column.name is not None: # is Identifier + if isinstance(column.name, ast.Star): + continue col_name = column.name.upper() if column.table is not None: table = column.table diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index 8dbab08..94afda8 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -142,14 +142,6 @@ def test_join_predictor_plan_limit(self): # plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}}) def test_join_predictor_plan_complex_query(self): - query = Select(targets=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')], - from_table=Join(left=Identifier('int.tab'), - right=Identifier('mindsdb.pred'), - join_type=JoinType.INNER_JOIN, - implicit=True), - group_by=[Identifier('tab.asset')], - having=BinaryOperation('=', args=[Identifier('tab.asset'), Constant('bitcoin')]) - ) sql = """ select t.asset, t.time, m.predicted diff --git a/tests/test_planner/test_join_tables.py b/tests/test_planner/test_join_tables.py index f624fd9..b85bafb 100644 --- a/tests/test_planner/test_join_tables.py +++ b/tests/test_planner/test_join_tables.py @@ -489,4 +489,35 @@ def test_join_one_integration(self): ) plan = plan_query(query, integrations=['int'], default_namespace='int') - assert plan.steps == expected_plan.steps \ No newline at end of file + assert plan.steps == expected_plan.steps + + def test_cte(self): + query = parse_sql(''' + with t1 as ( + select * from int1.tbl1 + ) + select t1.id, t2.* from t1 + join int2.tbl2 t2 on t1.id>t2.id + ''') + + subquery = copy.deepcopy(query) + subquery.from_table = None + + plan = plan_query(query, integrations=['int1', 'int2'], default_namespace='mindsdb') + + expected_plan = QueryPlan( + steps=[ + FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1')), + SubSelectStep(dataframe=Result(0), query=Select(targets=[Star()]), table_name='t1'), + FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl2 as t2')), + JoinStep(left=Result(1), + right=Result(2), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), + condition=BinaryOperation(op='>', args=[Identifier('t1.id'), Identifier('t2.id')]), + join_type=JoinType.JOIN)), + QueryStep(parse_sql('SELECT t1.`id`, t2.*'), from_table=Result(3)), + ] + ) + + assert plan.steps == expected_plan.steps