diff --git a/mindsdb_sql/__about__.py b/mindsdb_sql/__about__.py index 461e355f..e6aef77f 100644 --- a/mindsdb_sql/__about__.py +++ b/mindsdb_sql/__about__.py @@ -1,6 +1,6 @@ __title__ = 'mindsdb_sql' __package_name__ = 'mindsdb_sql' -__version__ = '0.8.1' +__version__ = '0.9.0' __description__ = "Pure python SQL parser" __email__ = "jorge@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/mindsdb_sql/parser/dialects/mindsdb/latest.py b/mindsdb_sql/parser/dialects/mindsdb/latest.py index 5a755df8..70b4beae 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/latest.py +++ b/mindsdb_sql/parser/dialects/mindsdb/latest.py @@ -6,8 +6,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, alias=None, parentheses=False, **kwargs) def to_tree(self, *args, level=0, **kwargs): - return '\t'*level + 'Latest()' + return '\t'*level + 'Latest()' def get_string(self, *args, **kwargs): return 'LATEST' - diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index 93d30e22..df5201d7 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -576,7 +576,7 @@ def add_order_not_null(condition): ) integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter) integration_selects = [integration_select] - elif isinstance(time_filter, BinaryOperation) and time_filter.op == '=' and time_filter.args[1] == Latest(): + elif isinstance(time_filter, BinaryOperation) and time_filter.op == '=': integration_select = Select(targets=[Star()], from_table=table, where=preparation_where, @@ -584,7 +584,25 @@ def add_order_not_null(condition): order_by=order_by, limit=Constant(predictor_window), ) - integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter) + + if type(time_filter.args[1]) is Latest: + integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter) + else: + time_filter_date = time_filter.args[1] + preparation_time_filter = BinaryOperation( + '<=', + args=[ + Identifier(predictor_time_column_name), + time_filter_date + ] + ) + integration_select.where = add_order_not_null( + replace_time_filter( + preparation_where2, time_filter, preparation_time_filter + ) + ) + time_filter.op = '>' + integration_selects = [integration_select] elif isinstance(time_filter, BinaryOperation) and time_filter.op in ('>', '>='): time_filter_date = time_filter.args[1] diff --git a/tests/test_planner/test_ts_predictor.py b/tests/test_planner/test_ts_predictor.py index 49b26a80..1f602d89 100644 --- a/tests/test_planner/test_ts_predictor.py +++ b/tests/test_planner/test_ts_predictor.py @@ -4,12 +4,23 @@ from mindsdb_sql import parse_sql from mindsdb_sql.exceptions import PlanningException -from mindsdb_sql.parser.ast import * +from mindsdb_sql.parser.ast import Select, Star, Identifier, Join, Constant, BinaryOperation, Update, BetweenOperation from mindsdb_sql.parser.dialects.mindsdb.latest import Latest from mindsdb_sql.planner import plan_query from mindsdb_sql.planner.query_plan import QueryPlan from mindsdb_sql.planner.step_result import Result -from mindsdb_sql.planner.steps import * +from mindsdb_sql.planner.steps import ( + JoinStep, + SaveToTable, + ProjectStep, + InsertToTable, + MapReduceStep, + MultipleSteps, + UpdateToTable, + LimitOffsetStep, + FetchDataframeStep, + ApplyTimeseriesPredictorStep +) from mindsdb_sql.parser.utils import JoinType @@ -725,7 +736,74 @@ def test_join_predictor_timeseries_concrete_date_less_or_equal(self): for i in range(len(plan.steps)): assert plan.steps[i] == expected_plan.steps[i] - + + def test_join_predictor_timeseries_concrete_date_equal(self): + predictor_window = 10 + group_by_column = 'vendor_id' + + sql = """ + select * from + mysql.data.ny_output as ta + join mindsdb.tp3 as tb + where + ta.pickup_hour = 10 + and ta.vendor_id = 1 + """ + + query = parse_sql(sql, dialect='mindsdb') + + expected_plan = QueryPlan( + steps=[ + FetchDataframeStep(integration='mysql', + query=Select(targets=[ + Identifier(parts=['ta', group_by_column], alias=Identifier(group_by_column))], + from_table=Identifier('data.ny_output', alias=Identifier('ta')), + where=BinaryOperation('=', args=[Identifier('ta.vendor_id'), Constant(1)]), + distinct=True, + ) + ), + MapReduceStep( + values=Result(0), + reduce='union', + step=FetchDataframeStep( + integration='mysql', + query=parse_sql(""" + SELECT * FROM data.ny_output AS ta + WHERE ta.pickup_hour <= 10 AND ta.vendor_id = 1 AND ta.pickup_hour is not null and + ta.vendor_id = '$var[vendor_id]' + ORDER BY ta.pickup_hour DESC LIMIT 10 + """), + ), + ), + ApplyTimeseriesPredictorStep( + output_time_filter=BinaryOperation('>', args=[Identifier('ta.pickup_hour'), Constant(10)]), + namespace='mindsdb', + predictor=Identifier('tp3', alias=Identifier('tb')), + dataframe=Result(1), + ), + JoinStep(left=Result(1), + right=Result(2), + query=Join( + right=Identifier('result_2'), + left=Identifier('result_1'), + join_type=JoinType.JOIN) + ), + ProjectStep(dataframe=Result(3), columns=[Star()]), + ], + ) + + plan = plan_query(query, + integrations=['mysql'], + predictor_namespace='mindsdb', + predictor_metadata={ + 'tp3': {'timeseries': True, + 'order_by_column': 'pickup_hour', + 'group_by_columns': [group_by_column], + 'window': predictor_window} + }) + + for i in range(len(plan.steps)): + assert plan.steps[i] == expected_plan.steps[i] def test_join_predictor_timeseries_error_on_nested_where(self): query = Select(targets=[Identifier('pred.time'), Identifier('pred.price')],