diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index 3089dc6..ea86784 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -21,7 +21,7 @@ class TableInfo: sub_select: ast.ASTNode = None predictor_info: dict = None join_condition = None - + index: int = None class PlanJoin: @@ -85,12 +85,15 @@ def __init__(self, planner): # index to lookup tables self.tables_idx = None + self.tables = [] + self.tables_fetch_step = {} self.step_stack = None self.query_context = {} self.partition = None + def plan(self, query): self.tables_idx = {} join_step = self.plan_join_tables(query) @@ -145,7 +148,8 @@ def resolve_table(self, table): return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select) def get_table_for_column(self, column: Identifier): - + if not isinstance(column, Identifier): + return # to lowercase parts = tuple(map(str.lower, column.parts[:-1])) if parts in self.tables_idx: @@ -160,6 +164,9 @@ def get_join_sequence(self, node, condition=None): for alias in table_info.aliases: self.tables_idx[alias] = table_info + table_info.index = len(self.tables) + self.tables.append(table_info) + table_info.predictor_info = self.planner.get_predictor(node) if condition is not None: @@ -375,13 +382,16 @@ def process_table(self, item, query_in): # not use conditions conditions = [] + conditions += self.get_filters_from_join_conditions(item) + if self.query_context['use_limit']: order_by = None if query_in.order_by is not None: order_by = [] # all order column be from this table for col in query_in.order_by: - if self.get_table_for_column(col.field).table != item.table: + table_info = self.get_table_for_column(col.field) + if table_info is None or table_info.table != item.table: order_by = False break col = copy.deepcopy(col) @@ -406,6 +416,8 @@ def process_table(self, item, query_in): # step = self.planner.get_integration_select_step(query2) step = FetchDataframeStep(integration=item.integration, query=query2) + self.tables_fetch_step[item.index] = step + self.add_plan_step(step) self.step_stack.append(step) @@ -440,6 +452,70 @@ def _check_conditions(node, **kwargs): query_traversal(model_table.join_condition, _check_conditions) return columns_map + def get_filters_from_join_conditions(self, fetch_table): + + binary_ops = set() + conditions = [] + data_conditions = [] + + def _check_conditions(node, **kwargs): + if not isinstance(node, BinaryOperation): + return + + if node.op != '=': + binary_ops.add(node.op.lower()) + return + + arg1, arg2 = node.args + table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None + table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None + + if table1 is not fetch_table: + if table2 is not fetch_table: + return + # set our table first + table1, table2 = table2, table1 + arg1, arg2 = arg2, arg1 + + if isinstance(arg2, Constant): + conditions.append(node) + elif table2 is not None: + data_conditions.append([arg1, arg2]) + + query_traversal(fetch_table.join_condition, _check_conditions) + + binary_ops.discard('and') + if len(binary_ops) > 0: + # other operations exists, skip + return [] + + for arg1, arg2 in data_conditions: + # is fetched? + table2 = self.get_table_for_column(arg2) + fetch_step = self.tables_fetch_step.get(table2.index) + + if fetch_step is None: + continue + + # extract distinct values + # remove aliases + arg1 = Identifier(parts=[arg1.parts[-1]]) + arg2 = Identifier(parts=[arg2.parts[-1]]) + + query2 = Select(targets=[arg2], distinct=True) + subselect_step = SubSelectStep(query2, fetch_step.result) + subselect_step = self.add_plan_step(subselect_step) + + conditions.append(BinaryOperation( + op='in', + args=[ + arg1, + Parameter(subselect_step.result) + ] + )) + + return conditions + def process_predictor(self, item, query_in): if len(self.step_stack) == 0: raise NotImplementedError("Predictor can't be first element of join syntax") diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index 40c792a..8dbab08 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -22,8 +22,6 @@ def test_join_predictor_plan(self): """ query = parse_sql(sql) - query_step = parse_sql("select tab1.column1, pred.predicted") - query_step.from_table = Parameter(Result(2)) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', @@ -75,7 +73,7 @@ def test_join_predictor_plan_aliases(self): plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) assert plan.steps == expected_plan.steps - + def test_join_predictor_plan_limit(self): @@ -116,7 +114,7 @@ def test_join_predictor_plan_limit(self): plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) assert plan.steps == expected_plan.steps - + # def test_join_predictor_error_when_filtering_on_predictions(self): # """ @@ -673,15 +671,16 @@ def test_complex_subselect(self): sql = ''' select t2.x, m.id, (select a from int.tab0 where x=0) from int.tab1 t1 - join int.tab2 t2 on t1.x = t2.x + join int.tab2 t2 on t1.x = t2.a join mindsdb.pred m where m.a=(select a from int.tab3 where x=3) and t2.x=(select a from int.tab4 where x=4) and t1.b=1 and t2.b=2 and t1.a = t2.a ''' - q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 ') - q_table2.where.args[0].args[1] = Parameter(Result(2)) + q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 AND a IN 1') + q_table2.where.args[0].args[0].args[1] = Parameter(Result(2)) + q_table2.where.args[1].args[1] = Parameter(Result(4)) subquery = parse_sql(""" select t2.x, m.id, x @@ -708,22 +707,23 @@ def test_complex_subselect(self): # tables FetchDataframeStep(integration='int', query=parse_sql('select * from tab1 as t1 where b=1')), + SubSelectStep(dataframe=Result(3), query=Select(targets=[Identifier('x')], distinct=True)), FetchDataframeStep(integration='int', query=q_table2), - JoinStep(left=Result(3), right=Result(4), + JoinStep(left=Result(3), right=Result(5), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN, - condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.x')]) + condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.a')]) ) ), # model - ApplyPredictorStep(namespace='mindsdb', dataframe=Result(5), + ApplyPredictorStep(namespace='mindsdb', dataframe=Result(6), predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': Result(1)}), - JoinStep(left=Result(5), right=Result(6), + JoinStep(left=Result(6), right=Result(7), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - QueryStep(subquery, from_table=Result(7)), + QueryStep(subquery, from_table=Result(8)), ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) diff --git a/tests/test_planner/test_join_tables.py b/tests/test_planner/test_join_tables.py index f44483c..f624fd9 100644 --- a/tests/test_planner/test_join_tables.py +++ b/tests/test_planner/test_join_tables.py @@ -16,7 +16,7 @@ def test_join_tables_plan(self): query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ) ) @@ -35,7 +35,7 @@ def test_join_tables_plan(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -45,13 +45,13 @@ def test_join_tables_plan(self): ) assert plan.steps == expected_plan.steps - + def test_join_tables_where_plan(self): query = parse_sql(''' SELECT tab1.column1, tab2.column1, tab2.column2 FROM int.tab1 - INNER JOIN int2.tab2 ON tab1.column1 = tab2.column1 + INNER JOIN int2.tab2 ON tab1.column1 > tab2.column1 WHERE ((tab1.column1 = 1) AND (tab2.column1 = 0)) AND (tab1.column3 = tab2.column3) @@ -71,7 +71,7 @@ def test_join_tables_where_plan(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -90,7 +90,7 @@ def test_join_tables_plan_groupby(self): Function('sum', args=[Identifier('tab2.column2')], alias=Identifier('total'))], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ), @@ -117,7 +117,7 @@ def test_join_tables_plan_groupby(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -126,13 +126,13 @@ def test_join_tables_plan_groupby(self): ], ) assert plan.steps == expected_plan.steps - + def test_join_tables_plan_limit_offset(self): query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ), limit=Constant(10), @@ -161,7 +161,7 @@ def test_join_tables_plan_limit_offset(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -171,13 +171,13 @@ def test_join_tables_plan_limit_offset(self): ) assert plan.steps == expected_plan.steps - + def test_join_tables_plan_order_by(self): query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ), limit=Constant(10), @@ -203,7 +203,7 @@ def test_join_tables_plan_order_by(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -213,7 +213,7 @@ def test_join_tables_plan_order_by(self): ) assert plan.steps == expected_plan.steps - + # This quiery should be sent to integration without raising exception # def test_join_tables_where_ambigous_column_error(self): @@ -278,7 +278,7 @@ def test_join_tables_disambiguate_identifiers_in_condition(self): for i in range(len(plan.steps)): assert plan.steps[i] == expected_plan.steps[i] - + def _disabled_test_join_tables_error_on_unspecified_table_in_condition(self): # disabled: identifier can be environment of system variable @@ -328,7 +328,7 @@ def test_join_tables_plan_default_namespace(self): def test_complex_join_tables(self): query = parse_sql(''' select * from int1.tbl1 t1 - right join int2.tbl2 t2 on t1.id=t2.id + right join int2.tbl2 t2 on t1.id>t2.id join pred m left join tbl3 on tbl3.id=t1.id where t1.a=1 and t2.b=2 and 1=1 @@ -337,6 +337,9 @@ def test_complex_join_tables(self): subquery = copy.deepcopy(query) subquery.from_table = None + q_table3 = parse_sql('select * from tbl3 where id in 0') + q_table3.where.args[1] = Parameter(Result(5)) + plan = plan_query(query, integrations=['int1', 'int2', 'proj'], default_namespace='proj', predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}]) @@ -349,7 +352,7 @@ def test_complex_join_tables(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), condition=BinaryOperation( - op='=', + op='>', args=[Identifier('t1.id'), Identifier('t2.id')]), join_type=JoinType.RIGHT_JOIN)), @@ -359,9 +362,10 @@ def test_complex_join_tables(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - FetchDataframeStep(integration='proj', query=parse_sql('select * from tbl3')), + SubSelectStep(dataframe=Result(0), query=Select(targets=[Identifier('id')], distinct=True)), + FetchDataframeStep(integration='proj', query=q_table3), JoinStep(left=Result(4), - right=Result(5), + right=Result(6), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), condition=BinaryOperation( @@ -369,7 +373,7 @@ def test_complex_join_tables(self): args=[Identifier('tbl3.id'), Identifier('t1.id')]), join_type=JoinType.LEFT_JOIN)), - QueryStep(subquery, from_table=Result(6)), + QueryStep(subquery, from_table=Result(7)), ] )