support: model.col = (subselect)
ea-rus committed Nov 30, 2023
1 parent c4c54de commit 6512e8c
Showing 4 changed files with 233 additions and 101 deletions.
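
The change adds planner support for predicates of the form model.column = (subselect) when a table is joined with a model. A minimal sketch of a query shape this appears to enable, parsed through the package's public entry point; the integration, table, and model names (int1, tab1, tab2, pred) are illustrative assumptions, not taken from the commit:

# Sketch only: a guess at the newly supported query shape; names are hypothetical.
from mindsdb_sql import parse_sql

query = parse_sql(
    '''
    SELECT * FROM int1.tab1 AS t
    JOIN mindsdb.pred AS p
    WHERE t.x > 10 AND p.param = (SELECT max(x) FROM int1.tab2)
    ''',
    dialect='mindsdb',
)
# The planner changes below plan the subselect as a nested select step and
# pass equality filters on model columns to ApplyPredictorStep via row_dict.
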
86 changes: 43 additions & 43 deletions mindsdb_sql/planner/query_planner.py
@@ -18,7 +18,7 @@
     get_deepest_select,
     recursively_extract_column_values,
     recursively_check_join_identifiers_for_ambiguity,
-    query_traversal)
+    query_traversal, filters_to_bin_op)
 from mindsdb_sql.planner.query_plan import QueryPlan
 from mindsdb_sql.planner import utils
 from .query_prepare import PreparedStatementPlanner
@@ -167,7 +167,7 @@ def _prepare_integration_select(node, is_table, is_target, parent_query, **kwargs):
             if isinstance(last_part, str):
                 node.alias = Identifier(parts=[node.parts[-1]])
 
-        utils.query_traversal(query, _prepare_integration_select)
+        query_traversal(query, _prepare_integration_select)
 
     def get_integration_select_step(self, select):
         integration_name, table = self.resolve_database_table(select.from_table)
@@ -230,7 +230,7 @@ def find_predictors(node, is_table, **kwargs):
             if isinstance(node, ast.NativeQuery) or isinstance(node, ast.Data):
                 mdb_entities.append(node)
 
-        utils.query_traversal(query, find_predictors)
+        query_traversal(query, find_predictors)
         return {'mdb_entities': mdb_entities, 'integrations': integrations, 'predictors': predictors}
 
     def get_nested_selects_plan_fnc(self, main_integration, force=False):
@@ -273,8 +273,8 @@ def plan_select_identifier(self, query):
         is_api_db = self.integrations.get(main_integration, {}).get('class_type') == 'api'
 
         find_selects = self.get_nested_selects_plan_fnc(main_integration, force=is_api_db)
-        query.targets = utils.query_traversal(query.targets, find_selects)
-        utils.query_traversal(query.where, find_selects)
+        query.targets = query_traversal(query.targets, find_selects)
+        query_traversal(query.where, find_selects)
 
         # get info of updated query
         query_info = self.get_query_info(query)
@@ -427,7 +427,7 @@ def split_filters(node, **kwargs):
             if not isinstance(arg1, Identifier):
                 arg1, arg2 = arg2, arg1
 
-            if isinstance(arg1, Identifier) and isinstance(arg2, Constant) and len(arg1.parts) > 1:
+            if isinstance(arg1, Identifier) and isinstance(arg2, (Constant, Parameter)) and len(arg1.parts) > 1:
                 model = Identifier(parts=arg1.parts[:-1])
 
                 if (
Expand All @@ -440,40 +440,43 @@ def split_filters(node, **kwargs):
return
table_filters.append(node)

query_traversal(int_select.where, split_filters)
# find subselects
main_integration, _ = self.resolve_database_table(table)
find_selects = self.get_nested_selects_plan_fnc(main_integration, force=True)
query_traversal(int_select.where, find_selects)

def filters_to_bin_op(filters):
# make a new where clause without params
where = None
for flt in filters:
if where is None:
where = flt
else:
where = BinaryOperation(op='and', args=[where, flt])
return where
# split conditions
query_traversal(int_select.where, split_filters)

model_where = None
if len(model_filters) > 0 and 'or' not in binary_ops:
int_select.where = filters_to_bin_op(table_filters)
model_where = filters_to_bin_op(model_filters)

integration_select_step = self.plan_integration_select(int_select)

predictor_identifier = utils.get_predictor_name_identifier(predictor)

if len(params) == 0:
params = None

row_dict = None
if model_filters:
row_dict = {}
for el in model_filters:
if isinstance(el.args[0], Identifier) and el.op == '=':
if isinstance(el.args[1], (Constant, Parameter)):
row_dict[el.args[0].parts[-1]] = el.args[1].value

last_step = self.plan.add_step(ApplyPredictorStep(
namespace=predictor_namespace,
dataframe=integration_select_step.result,
predictor=predictor_identifier,
params=params
params=params,
row_dict=row_dict
))

return {
'predictor': last_step,
'data': integration_select_step,
'model_filters': model_where,
}

def plan_fetch_timeseries_partitions(self, query, table, predictor_group_by_names):
@@ -795,7 +798,7 @@ def _check_condition(node, **kwargs):
             if not isinstance(arg1, Identifier):
                 arg1, arg2 = arg2, arg1
 
-            if isinstance(arg1, Identifier) and isinstance(arg2, Constant):
+            if isinstance(arg1, Identifier) and isinstance(arg2, (Constant, Parameter)):
                 if len(arg1.parts) < 2:
                     return
 
Expand All @@ -810,6 +813,9 @@ def _check_condition(node, **kwargs):
node2._orig_node = node
tables_idx[parts]['conditions'].append(node2)

find_selects = self.get_nested_selects_plan_fnc(self.default_namespace, force=True)
query_traversal(query.where, find_selects)

query_traversal(query.where, _check_condition)

# create plan
Expand All @@ -826,12 +832,7 @@ def _check_condition(node, **kwargs):
item['sub_select'].parentheses = False
step = self.plan_select(item['sub_select'])

where = None
for cond in item['conditions']:
if where is None:
where = cond
else:
where = BinaryOperation(op='and', args=[where, cond])
where = filters_to_bin_op(item['conditions'])

# apply table alias
query2 = Select(targets=[Star()], where=where)
Expand All @@ -857,8 +858,10 @@ def _check_condition(node, **kwargs):
if item['conditions']:
row_dict = {}
for el in item['conditions']:
if isinstance(el.args[0], Identifier) and isinstance(el.args[1], Constant) and el.op == '=':
row_dict[el.args[0].parts[-1]] = el.args[1].value
if isinstance(el.args[0], Identifier) and el.op == '=':

if isinstance(el.args[1], (Constant, Parameter)):
row_dict[el.args[0].parts[-1]] = el.args[1].value

# exclude condition
item['conditions'][0]._orig_node.args = [Constant(0), Constant(0)]
Expand All @@ -873,7 +876,6 @@ def _check_condition(node, **kwargs):
step_stack.append(predictor_step)
else:
# is table

query2 = Select(from_table=table_name, targets=[Star()])
# parts = tuple(map(str.lower, table_name.parts))
conditions = item['conditions']
Expand All @@ -887,6 +889,7 @@ def _check_condition(node, **kwargs):
else:
query2.where = cond

# TODO use self.get_integration_select_step(query2)
step = FetchDataframeStep(integration=item['integration'], query=query2)
self.plan.add_step(step)
step_stack.append(step)
Expand All @@ -899,6 +902,7 @@ def _check_condition(node, **kwargs):
# TODO
new_join.left = Identifier('tab1')
new_join.right = Identifier('tab2')
new_join.implicit = False

step = self.plan.add_step(JoinStep(left=step_left.result, right=step_right.result, query=new_join))

@@ -1042,26 +1046,21 @@ def plan_join(self, query, integration=None):
 
         aliased_fields = self.get_aliased_fields(query.targets)
 
-        recursively_check_join_identifiers_for_ambiguity(query.where)
-        recursively_check_join_identifiers_for_ambiguity(query.group_by, aliased_fields=aliased_fields)
-        recursively_check_join_identifiers_for_ambiguity(query.having)
-        recursively_check_join_identifiers_for_ambiguity(query.order_by, aliased_fields=aliased_fields)
-
         # check predictor
         predictor = None
         table = None
         predictor_namespace = None
         predictor_is_left = False
 
-        if not (isinstance(join_right, Identifier) and self.is_predictor(join_right)):
+        if not self.is_predictor(join_right):
             # predictor not in the right, swap
             join_left, join_right = join_right, join_left
             predictor_is_left = True
 
-        if isinstance(join_right, Identifier) and self.is_predictor(join_right):
+        if self.is_predictor(join_right):
             # predictor is in the right now
 
-            if isinstance(join_left, Identifier) and self.is_predictor(join_left):
+            if self.is_predictor(join_left):
                 # left is predictor too
 
                 raise PlanningException(f'Can\'t join two predictors {str(join_left.parts[0])} and {str(join_left.parts[1])}')
Expand All @@ -1077,6 +1076,11 @@ def plan_join(self, query, integration=None):
# Apply mindsdb model to result of last dataframe fetch
# Then join results of applying mindsdb with table

recursively_check_join_identifiers_for_ambiguity(query.where)
recursively_check_join_identifiers_for_ambiguity(query.group_by, aliased_fields=aliased_fields)
recursively_check_join_identifiers_for_ambiguity(query.having)
recursively_check_join_identifiers_for_ambiguity(query.order_by, aliased_fields=aliased_fields)

if self.get_predictor(predictor).get('timeseries'):
predictor_steps = self.plan_timeseries_predictor(query, table, predictor_namespace, predictor)
else:
Expand All @@ -1101,10 +1105,6 @@ def plan_join(self, query, integration=None):

last_step = self.plan.add_step(JoinStep(left=left, right=right, query=new_join))

if predictor_steps.get('model_filters'):
last_step = self.plan.add_step(FilterStep(dataframe=last_step.result,
query=predictor_steps['model_filters']))

# limit from timeseries
if predictor_steps.get('saved_limit'):
last_step = self.plan.add_step(LimitOffsetStep(dataframe=last_step.result,
@@ -1222,7 +1222,7 @@ def plan_delete(self, query: Delete):
         is_api_db = self.integrations.get(main_integration, {}).get('class_type') == 'api'
 
         find_selects = self.get_nested_selects_plan_fnc(main_integration, force=is_api_db)
-        utils.query_traversal(query.where, find_selects)
+        query_traversal(query.where, find_selects)
 
         self.prepare_integration_select(main_integration, query.where)
 
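Taken together, the query_planner.py changes split a joined query's WHERE clause: equality conditions on the model's columns are handed to ApplyPredictorStep as row_dict, everything else stays in the integration SELECT, and the old post-join FilterStep on model_filters is dropped. Below is a standalone sketch of that splitting idea in plain Python; Cond and split_conditions are hypothetical stand-ins, not mindsdb_sql classes.

from dataclasses import dataclass

@dataclass
class Cond:
    column: str    # dotted column reference, e.g. 'p.param' or 't.x'
    op: str
    value: object

def split_conditions(conds, model_alias):
    # Equality filters on the model's columns become row parameters;
    # the rest remain filters on the fetched data.
    table_filters, row_dict = [], {}
    for c in conds:
        owner, _, col = c.column.rpartition('.')
        if owner == model_alias and c.op == '=':
            row_dict[col] = c.value     # routed to the predictor call
        else:
            table_filters.append(c)     # stays in the data-fetch SELECT
    return table_filters, row_dict

# WHERE t.x > 10 AND p.param = 42, with the model aliased as p:
table_filters, row_dict = split_conditions(
    [Cond('t.x', '>', 10), Cond('p.param', '=', 42)], model_alias='p')
assert row_dict == {'param': 42}
assert [c.column for c in table_filters] == ['t.x']

As in the diff, the split is only applied when the conditions are AND-ed; the planner skips it if an 'or' appears among the binary operators.
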
39 changes: 25 additions & 14 deletions mindsdb_sql/planner/utils.py
@@ -1,25 +1,26 @@
 import copy
+from typing import List
 
 from mindsdb_sql.exceptions import PlanningException
 from mindsdb_sql.parser.ast import (Identifier, Operation, Star, Select, BinaryOperation, Constant,
                                     OrderBy, UnaryOperation, NullConstant, TypeCast, Parameter)
 from mindsdb_sql.parser import ast
 
 
-def get_integration_path_from_identifier(identifier):
-    parts = identifier.parts
-    integration_name = parts[0]
-    new_parts = parts[1:]
-
-    if len(parts) == 1:
-        raise PlanningException(f'No integration specified for table: {str(identifier)}')
-    elif len(parts) > 4:
-        raise PlanningException(f'Too many parts (dots) in table identifier: {str(identifier)}')
-
-    new_identifier = copy.deepcopy(identifier)
-    new_identifier.parts = new_parts
-
-    return integration_name, new_identifier
+# def get_integration_path_from_identifier(identifier):
+#     parts = identifier.parts
+#     integration_name = parts[0]
+#     new_parts = parts[1:]
+#
+#     if len(parts) == 1:
+#         raise PlanningException(f'No integration specified for table: {str(identifier)}')
+#     elif len(parts) > 4:
+#         raise PlanningException(f'Too many parts (dots) in table identifier: {str(identifier)}')
+#
+#     new_identifier = copy.deepcopy(identifier)
+#     new_identifier.parts = new_parts
+#
+#     return integration_name, new_identifier
 
 
 def get_predictor_name_identifier(identifier):
@@ -354,3 +355,13 @@ def params_replace(node, **kwargs):
 
     return query
 
+
+def filters_to_bin_op(filters: List[BinaryOperation]):
+    # make a new where clause without params
+    where = None
+    for flt in filters:
+        if where is None:
+            where = flt
+        else:
+            where = BinaryOperation(op='and', args=[where, flt])
+    return where
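
The new helper is the nested filters_to_bin_op previously defined inside query_planner.py, promoted to a shared utility. Assuming it behaves exactly as defined above, a short usage sketch with AST classes this module already imports:

from mindsdb_sql.parser.ast import BinaryOperation, Constant, Identifier
from mindsdb_sql.planner.utils import filters_to_bin_op

filters = [
    BinaryOperation(op='=', args=[Identifier('a'), Constant(1)]),
    BinaryOperation(op='>', args=[Identifier('b'), Constant(10)]),
]

where = filters_to_bin_op(filters)
# where is BinaryOperation(op='and', args=[<a = 1>, <b > 10>]): a single tree
# that can be assigned to Select.where. An empty list returns None, i.e. no WHERE.
assert filters_to_bin_op([]) is None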