Skip to content

Commit

Permalink
Merge pull request #305 from mindsdb/model-filter
Browse files Browse the repository at this point in the history
Filter for model in 'table join model'
  • Loading branch information
ea-rus authored Sep 21, 2023
2 parents 8619961 + 6eeade5 commit a1c3d14
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 50 deletions.
54 changes: 31 additions & 23 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,23 +407,25 @@ def plan_predictor(self, query, table, predictor_namespace, predictor):
params = query.using

binary_ops = []
filters = []
table_filters = []
model_filters = []

def split_filters(node, **kwargs):
# split conditions between model and table

def extract_predictor_params(node, **kwargs):
if isinstance(node, BinaryOperation):
op = node.op.lower()

binary_ops.append(op)

if op != '=':
if op in ['and', 'or']:
return

arg1, arg2 = node.args
if not isinstance(arg1, Identifier):
arg1, arg2 = arg2, arg1

if isinstance(arg1, Identifier) and isinstance(arg2, Constant) and len(arg1.parts) > 1:
col = arg1.parts[-1]
model = Identifier(parts=arg1.parts[:-1])

if (
Expand All @@ -432,42 +434,44 @@ def extract_predictor_params(node, **kwargs):
len(model.parts) == 1 and model.parts[0] == predictor_alias
)
):
params[col] = arg2.value
model_filters.append(node)
return
filters.append(node)
table_filters.append(node)

query_traversal(int_select.where, extract_predictor_params)
query_traversal(int_select.where, split_filters)

if len(params) > 0:
if 'or' in binary_ops:
# rollback
params = {}
else:
# make a new where clause without params
where = None
for flt in filters:
if where is None:
where = flt
else:
where = BinaryOperation(op='and', args=[where, flt])
int_select.where = where
def filters_to_bin_op(filters):
# make a new where clause without params
where = None
for flt in filters:
if where is None:
where = flt
else:
where = BinaryOperation(op='and', args=[where, flt])
return where

model_where = None
if len(model_filters) > 0 and 'or' not in binary_ops:
int_select.where = filters_to_bin_op(table_filters)
model_where = filters_to_bin_op(model_filters)

integration_select_step = self.plan_integration_select(int_select)

predictor_identifier = utils.get_predictor_name_identifier(predictor)

if len(params) == 0:
params = None
predictor_step = self.plan.add_step(ApplyPredictorStep(
last_step = self.plan.add_step(ApplyPredictorStep(
namespace=predictor_namespace,
dataframe=integration_select_step.result,
predictor=predictor_identifier,
params=params
))

return {
'predictor': predictor_step,
'data': integration_select_step
'predictor': last_step,
'data': integration_select_step,
'model_filters': model_where,
}

def plan_fetch_timeseries_partitions(self, query, table, predictor_group_by_names):
Expand Down Expand Up @@ -1080,6 +1084,10 @@ def plan_join(self, query, integration=None):

last_step = self.plan.add_step(JoinStep(left=left, right=right, query=new_join))

if predictor_steps.get('model_filters'):
last_step = self.plan.add_step(FilterStep(dataframe=last_step.result,
query=predictor_steps['model_filters']))

# limit from timeseries
if predictor_steps.get('saved_limit'):
last_step = self.plan.add_step(LimitOffsetStep(dataframe=last_step.result,
Expand Down
64 changes: 37 additions & 27 deletions tests/test_planner/test_join_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mindsdb_sql.planner import plan_query
from mindsdb_sql.planner.query_plan import QueryPlan
from mindsdb_sql.planner.step_result import Result
from mindsdb_sql.planner.steps import (FetchDataframeStep, ProjectStep, JoinStep, ApplyPredictorStep,
from mindsdb_sql.planner.steps import (FetchDataframeStep, ProjectStep, JoinStep, ApplyPredictorStep, FilterStep,
LimitOffsetStep, GroupByStep, SubSelectStep, ApplyPredictorRowStep)
from mindsdb_sql.parser.utils import JoinType
from mindsdb_sql import parse_sql
Expand Down Expand Up @@ -131,30 +131,30 @@ def test_join_predictor_plan_where(self):
assert plan.steps == expected_plan.steps


def test_join_predictor_error_when_filtering_on_predictions(self):
"""
Query:
SELECT rental_price_confidence
FROM postgres_90.test_data.home_rentals AS ta
JOIN mindsdb.hrp3 AS tb
WHERE ta.sqft > 1000 AND tb.rental_price_confidence > 0.5
LIMIT 5;
"""

query = Select(targets=[Identifier('rental_price_confidence')],
from_table=Join(left=Identifier('postgres_90.test_data.home_rentals', alias=Identifier('ta')),
right=Identifier('mindsdb.hrp3', alias=Identifier('tb')),
join_type=JoinType.INNER_JOIN,
implicit=True),
where=BinaryOperation('and', args=[
BinaryOperation('>', args=[Identifier('ta.sqft'), Constant(1000)]),
BinaryOperation('>', args=[Identifier('tb.rental_price_confidence'), Constant(0.5)]),
]),
limit=5
)

with pytest.raises(PlanningException):
plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}})
# def test_join_predictor_error_when_filtering_on_predictions(self):
# """
# Query:
# SELECT rental_price_confidence
# FROM postgres_90.test_data.home_rentals AS ta
# JOIN mindsdb.hrp3 AS tb
# WHERE ta.sqft > 1000 AND tb.rental_price_confidence > 0.5
# LIMIT 5;
# """
#
# query = Select(targets=[Identifier('rental_price_confidence')],
# from_table=Join(left=Identifier('postgres_90.test_data.home_rentals', alias=Identifier('ta')),
# right=Identifier('mindsdb.hrp3', alias=Identifier('tb')),
# join_type=JoinType.INNER_JOIN,
# implicit=True),
# where=BinaryOperation('and', args=[
# BinaryOperation('>', args=[Identifier('ta.sqft'), Constant(1000)]),
# BinaryOperation('>', args=[Identifier('tb.rental_price_confidence'), Constant(0.5)]),
# ]),
# limit=5
# )
#
# with pytest.raises(PlanningException):
# plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}})

def test_join_predictor_plan_group_by(self):
query = Select(targets=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')],
Expand Down Expand Up @@ -646,12 +646,22 @@ def test_where_using(self):
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as a where a.x=1 and a.y=3', dialect='mindsdb')),
ApplyPredictorStep(namespace='proj', dataframe=Result(0),
predictor=Identifier('pred.1', alias=Identifier('p')), params={'x': 1, 'y': ''}),
predictor=Identifier('pred.1', alias=Identifier('p'))),
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('result_0'),
right=Identifier('result_1'),
join_type=JoinType.JOIN)),
ProjectStep(dataframe=Result(2), columns=[Star()]),
FilterStep(dataframe=Result(2), query=BinaryOperation(op='and', args=[
BinaryOperation(op='=', args=[
Identifier(parts=['p', 'x']),
Constant(1)
]),
BinaryOperation(op='=', args=[
Identifier(parts=['p', 'y']),
Constant('')
]),
])),
ProjectStep(dataframe=Result(3), columns=[Star()]),
],
)

Expand Down

0 comments on commit a1c3d14

Please sign in to comment.