Skip to content

Commit

Permalink
cte support
Browse files Browse the repository at this point in the history
  • Loading branch information
ea-rus committed Nov 1, 2024
1 parent c50e06c commit 61d8141
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 25 deletions.
6 changes: 3 additions & 3 deletions mindsdb_sql/planner/plan_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ def process_subselect(self, item):
self.step_stack.append(step2)

def process_table(self, item, query_in):
query2 = Select(from_table=item.table, targets=[Star()])
table = Identifier(parts=[item.integration] + item.table.parts)
query2 = Select(from_table=table, targets=[Star()])
# parts = tuple(map(str.lower, table_name.parts))
conditions = item.conditions
if 'or' in self.query_context['binary_ops']:
Expand Down Expand Up @@ -404,8 +405,7 @@ def process_table(self, item, query_in):
else:
query2.where = cond

# step = self.planner.get_integration_select_step(query2)
step = FetchDataframeStep(integration=item.integration, query=query2)
step = self.planner.get_integration_select_step(query2)
self.add_plan_step(step)
self.step_stack.append(step)

Expand Down
17 changes: 17 additions & 0 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(self,

self.statement = None

self.cte_results = {}

def is_predictor(self, identifier):
if not isinstance(identifier, Identifier):
return False
Expand Down Expand Up @@ -158,6 +160,12 @@ def get_integration_select_step(self, select):
else:
integration_name, table = self.resolve_database_table(select.from_table)

# Is this table a CTE that was already planned? If so, select from its stored result.
table_name = table.parts[-1]
if integration_name == self.default_namespace and table_name in self.cte_results:
select.from_table = None
return SubSelectStep(select, self.cte_results[table_name], table_name=table_name)

fetch_df_select = copy.deepcopy(select)
self.prepare_integration_select(integration_name, fetch_df_select)

Expand Down Expand Up @@ -663,10 +671,19 @@ def plan_delete(self, query: Delete):
where=query.where
))

def plan_cte(self, query):
    """Plan every CTE attached to *query* and remember each one's result.

    Each entry of ``query.cte`` is planned through ``self.plan_select``; the
    ``result`` of the returned step is stored in ``self.cte_results`` under
    the last part of the CTE's name, so a later lookup by bare table name
    can find it.

    Does nothing when ``query.cte`` is ``None`` or empty.
    """
    # `or ()` keeps a direct call safe when the query carries no CTE list.
    for cte in query.cte or ():
        step = self.plan_select(cte.query)
        # Only the last name part is used as the key (qualifiers are dropped).
        name = cte.name.parts[-1]
        self.cte_results[name] = step.result

def plan_select(self, query, integration=None):
if isinstance(query, Union):
return self.plan_union(query, integration=integration)

if query.cte is not None:
self.plan_cte(query)

from_table = query.from_table

if isinstance(from_table, Identifier):
Expand Down
30 changes: 11 additions & 19 deletions tests/test_planner/test_join_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_join_predictor_plan_aliases(self):
steps=[
FetchDataframeStep(integration='int',
query=Select(targets=[Star()],
from_table=Identifier('tab1', alias=Identifier('ta'))),
from_table=Identifier('tab1')),
),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('tb'))),
JoinStep(left=Result(0), right=Result(1),
Expand Down Expand Up @@ -144,14 +144,6 @@ def test_join_predictor_plan_limit(self):
# plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}})

def test_join_predictor_plan_complex_query(self):
query = Select(targets=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')],
from_table=Join(left=Identifier('int.tab'),
right=Identifier('mindsdb.pred'),
join_type=JoinType.INNER_JOIN,
implicit=True),
group_by=[Identifier('tab.asset')],
having=BinaryOperation('=', args=[Identifier('tab.asset'), Constant('bitcoin')])
)

sql = """
select t.asset, t.time, m.predicted
Expand All @@ -172,7 +164,7 @@ def test_join_predictor_plan_complex_query(self):
steps=[
FetchDataframeStep(
integration='int',
query=parse_sql("select * from tab as t where col1 = 'x'")
query=parse_sql("select * from tab where col1 = 'x'")
),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('m'))),
JoinStep(left=Result(0), right=Result(1),
Expand Down Expand Up @@ -545,7 +537,7 @@ def test_where_using(self):
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as a where x=1 and y=3')),
query=parse_sql('select * from tab1 where x=1 and y=3')),
ApplyPredictorStep(
namespace='proj', dataframe=Result(0),
predictor=Identifier('pred.1', alias=Identifier('p')),
Expand Down Expand Up @@ -616,7 +608,7 @@ def test_model_param(self):
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as t where b=2')),
query=parse_sql('select * from tab1 where b=2')),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0),
predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 1}),
JoinStep(left=Result(0), right=Result(1),
Expand Down Expand Up @@ -649,9 +641,9 @@ def test_model_param(self):
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as t')),
query=parse_sql('select * from tab1')),
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab2 as t2')),
query=parse_sql('select * from tab2')),
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
Expand Down Expand Up @@ -680,7 +672,7 @@ def test_complex_subselect(self):
and t1.b=1 and t2.b=2 and t1.a = t2.a
'''

q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 ')
q_table2 = parse_sql('select * from tab2 where x=0 and b=2 ')
q_table2.where.args[0].args[1] = Parameter(Result(2))

subquery = parse_sql("""
Expand All @@ -707,7 +699,7 @@ def test_complex_subselect(self):
query=parse_sql('select a as a from tab4 where x=4')),
# tables
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as t1 where b=1')),
query=parse_sql('select * from tab1 where b=1')),
FetchDataframeStep(integration='int', query=q_table2),
JoinStep(left=Result(3), right=Result(4),
query=Join(left=Identifier('tab1'),
Expand Down Expand Up @@ -751,7 +743,7 @@ def test_model_join_model(self):
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as t')),
query=parse_sql('select * from tab1')),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0),
predictor=Identifier('pred', alias=Identifier('m')),
row_dict={ 'a': 2 }, params={ 'param1': 'a', 'param3': 'c' }),
Expand Down Expand Up @@ -790,7 +782,7 @@ def test_model_column_map(self):
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as a')),
query=parse_sql('select * from tab1')),
ApplyPredictorStep(
namespace='proj', dataframe=Result(0),
predictor=Identifier('pred.1', alias=Identifier('p')),
Expand Down Expand Up @@ -829,7 +821,7 @@ def test_partition(self):
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as a')),
query=parse_sql('select * from tab1')),
MapReduceStep(
values=Result(0),
step=[
Expand Down
6 changes: 3 additions & 3 deletions tests/test_planner/test_join_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ def test_complex_join_tables(self):

expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1 as t1 where a=1')),
FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl2 as t2 where b=2')),
FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1 where a=1')),
FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl2 where b=2')),
JoinStep(left=Result(0),
right=Result(1),
query=Join(left=Identifier('tab1'),
Expand Down Expand Up @@ -389,7 +389,7 @@ def test_complex_join_tables_subselect(self):

expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1 as t1')),
FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1')),
FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl3')),
ApplyPredictorStep(namespace='proj', dataframe=Result(1),
predictor=Identifier('pred', alias=Identifier('m'))),
Expand Down

0 comments on commit 61d8141

Please sign in to comment.