From 61d8141042fdb734c71207206e0221f83fca033c Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 1 Nov 2024 22:27:34 +0300 Subject: [PATCH] cte support --- mindsdb_sql/planner/plan_join.py | 6 ++--- mindsdb_sql/planner/query_planner.py | 17 +++++++++++++ tests/test_planner/test_join_predictor.py | 30 +++++++++-------------- tests/test_planner/test_join_tables.py | 6 ++--- 4 files changed, 34 insertions(+), 25 deletions(-) diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index 3089dc6..0b02df8 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -368,7 +368,8 @@ def process_subselect(self, item): self.step_stack.append(step2) def process_table(self, item, query_in): - query2 = Select(from_table=item.table, targets=[Star()]) + table = Identifier(parts=[item.integration] + item.table.parts) + query2 = Select(from_table=table, targets=[Star()]) # parts = tuple(map(str.lower, table_name.parts)) conditions = item.conditions if 'or' in self.query_context['binary_ops']: @@ -404,8 +405,7 @@ def process_table(self, item, query_in): else: query2.where = cond - # step = self.planner.get_integration_select_step(query2) - step = FetchDataframeStep(integration=item.integration, query=query2) + step = self.planner.get_integration_select_step(query2) self.add_plan_step(step) self.step_stack.append(step) diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index 90b697b..f2d08d1 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -86,6 +86,8 @@ def __init__(self, self.statement = None + self.cte_results = {} + def is_predictor(self, identifier): if not isinstance(identifier, Identifier): return False @@ -158,6 +160,12 @@ def get_integration_select_step(self, select): else: integration_name, table = self.resolve_database_table(select.from_table) + # is it CTE? + table_name = table.parts[-1] + if integration_name == self.default_namespace and table_name in self.cte_results: + select.from_table = None + return SubSelectStep(select, self.cte_results[table_name], table_name=table_name) + fetch_df_select = copy.deepcopy(select) self.prepare_integration_select(integration_name, fetch_df_select) @@ -663,10 +671,19 @@ def plan_delete(self, query: Delete): where=query.where )) + def plan_cte(self, query): + for cte in query.cte: + step = self.plan_select(cte.query) + name = cte.name.parts[-1] + self.cte_results[name] = step.result + def plan_select(self, query, integration=None): if isinstance(query, Union): return self.plan_union(query, integration=integration) + if query.cte is not None: + self.plan_cte(query) + from_table = query.from_table if isinstance(from_table, Identifier): diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index 40c792a..6970a97 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -62,7 +62,7 @@ def test_join_predictor_plan_aliases(self): steps=[ FetchDataframeStep(integration='int', query=Select(targets=[Star()], - from_table=Identifier('tab1', alias=Identifier('ta'))), + from_table=Identifier('tab1')), ), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('tb'))), JoinStep(left=Result(0), right=Result(1), @@ -144,14 +144,6 @@ def test_join_predictor_plan_limit(self): # plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}}) def test_join_predictor_plan_complex_query(self): - query = Select(targets=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')], - from_table=Join(left=Identifier('int.tab'), - right=Identifier('mindsdb.pred'), - join_type=JoinType.INNER_JOIN, - implicit=True), - group_by=[Identifier('tab.asset')], - having=BinaryOperation('=', args=[Identifier('tab.asset'), Constant('bitcoin')]) - ) sql = """ select t.asset, t.time, m.predicted @@ -172,7 +164,7 @@ def test_join_predictor_plan_complex_query(self): steps=[ FetchDataframeStep( integration='int', - query=parse_sql("select * from tab as t where col1 = 'x'") + query=parse_sql("select * from tab where col1 = 'x'") ), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('m'))), JoinStep(left=Result(0), right=Result(1), @@ -545,7 +537,7 @@ def test_where_using(self): expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as a where x=1 and y=3')), + query=parse_sql('select * from tab1 where x=1 and y=3')), ApplyPredictorStep( namespace='proj', dataframe=Result(0), predictor=Identifier('pred.1', alias=Identifier('p')), @@ -616,7 +608,7 @@ def test_model_param(self): expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t where b=2')), + query=parse_sql('select * from tab1 where b=2')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 1}), JoinStep(left=Result(0), right=Result(1), @@ -649,9 +641,9 @@ def test_model_param(self): expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t')), + query=parse_sql('select * from tab1')), FetchDataframeStep(integration='int', - query=parse_sql('select * from tab2 as t2')), + query=parse_sql('select * from tab2')), JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), @@ -680,7 +672,7 @@ def test_complex_subselect(self): and t1.b=1 and t2.b=2 and t1.a = t2.a ''' - q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 ') + q_table2 = parse_sql('select * from tab2 where x=0 and b=2 ') q_table2.where.args[0].args[1] = Parameter(Result(2)) subquery = parse_sql(""" @@ -707,7 +699,7 @@ def test_complex_subselect(self): query=parse_sql('select a as a from tab4 where x=4')), # tables FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t1 where b=1')), + query=parse_sql('select * from tab1 where b=1')), FetchDataframeStep(integration='int', query=q_table2), JoinStep(left=Result(3), right=Result(4), query=Join(left=Identifier('tab1'), @@ -751,7 +743,7 @@ def test_model_join_model(self): expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t')), + query=parse_sql('select * from tab1')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('m')), row_dict={ 'a': 2 }, params={ 'param1': 'a', 'param3': 'c' }), @@ -790,7 +782,7 @@ def test_model_column_map(self): expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as a')), + query=parse_sql('select * from tab1')), ApplyPredictorStep( namespace='proj', dataframe=Result(0), predictor=Identifier('pred.1', alias=Identifier('p')), @@ -829,7 +821,7 @@ def test_partition(self): expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as a')), + query=parse_sql('select * from tab1')), MapReduceStep( values=Result(0), step=[ diff --git a/tests/test_planner/test_join_tables.py b/tests/test_planner/test_join_tables.py index f44483c..a9cbf3a 100644 --- a/tests/test_planner/test_join_tables.py +++ b/tests/test_planner/test_join_tables.py @@ -342,8 +342,8 @@ def test_complex_join_tables(self): expected_plan = QueryPlan( steps=[ - FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1 as t1 where a=1')), - FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl2 as t2 where b=2')), + FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1 where a=1')), + FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl2 where b=2')), JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), @@ -389,7 +389,7 @@ def test_complex_join_tables_subselect(self): expected_plan = QueryPlan( steps=[ - FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1 as t1')), + FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1')), FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl3')), ApplyPredictorStep(namespace='proj', dataframe=Result(1), predictor=Identifier('pred', alias=Identifier('m'))),