diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py
index 6f0f56b..1bb3afa 100644
--- a/mindsdb_sql/planner/plan_join.py
+++ b/mindsdb_sql/planner/plan_join.py
@@ -21,7 +21,7 @@ class TableInfo:
     sub_select: ast.ASTNode = None
     predictor_info: dict = None
     join_condition = None
-
+    index: int = None
 
 
 class PlanJoin:
@@ -85,12 +85,15 @@ def __init__(self, planner):
 
         # index to lookup tables
         self.tables_idx = None
+        self.tables = []
+        self.tables_fetch_step = {}
 
         self.step_stack = None
         self.query_context = {}
 
         self.partition = None
+
     def plan(self, query):
         self.tables_idx = {}
         join_step = self.plan_join_tables(query)
@@ -146,7 +149,8 @@ def resolve_table(self, table):
         return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select)
 
     def get_table_for_column(self, column: Identifier):
-
+        if not isinstance(column, Identifier):
+            return
         # to lowercase
         parts = tuple(map(str.lower, column.parts[:-1]))
         if parts in self.tables_idx:
@@ -161,6 +165,9 @@ def get_join_sequence(self, node, condition=None):
             for alias in table_info.aliases:
                 self.tables_idx[alias] = table_info
 
+            table_info.index = len(self.tables)
+            self.tables.append(table_info)
+
             table_info.predictor_info = self.planner.get_predictor(node)
 
             if condition is not None:
@@ -378,13 +385,16 @@ def process_table(self, item, query_in):
         # not use conditions
         conditions = []
 
+        conditions += self.get_filters_from_join_conditions(item)
+
         if self.query_context['use_limit']:
             order_by = None
             if query_in.order_by is not None:
                 order_by = []
                 # all order column be from this table
                 for col in query_in.order_by:
-                    if self.get_table_for_column(col.field).table != item.table:
+                    table_info = self.get_table_for_column(col.field)
+                    if table_info is None or table_info.table != item.table:
                         order_by = False
                         break
                     col = copy.deepcopy(col)
@@ -408,6 +418,8 @@ def process_table(self, item, query_in):
             query2.where = cond
 
         step = self.planner.get_integration_select_step(query2)
+        self.tables_fetch_step[item.index] = step
+
         self.add_plan_step(step)
         self.step_stack.append(step)
 
@@ -442,6 +454,70 @@ def _check_conditions(node, **kwargs):
         query_traversal(model_table.join_condition, _check_conditions)
         return columns_map
 
+    def get_filters_from_join_conditions(self, fetch_table):
+
+        binary_ops = set()
+        conditions = []
+        data_conditions = []
+
+        def _check_conditions(node, **kwargs):
+            if not isinstance(node, BinaryOperation):
+                return
+
+            if node.op != '=':
+                binary_ops.add(node.op.lower())
+                return
+
+            arg1, arg2 = node.args
+            table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None
+            table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None
+
+            if table1 is not fetch_table:
+                if table2 is not fetch_table:
+                    return
+                # set our table first
+                table1, table2 = table2, table1
+                arg1, arg2 = arg2, arg1
+
+            if isinstance(arg2, Constant):
+                conditions.append(node)
+            elif table2 is not None:
+                data_conditions.append([arg1, arg2])
+
+        query_traversal(fetch_table.join_condition, _check_conditions)
+
+        binary_ops.discard('and')
+        if len(binary_ops) > 0:
+            # other operations exist, skip
+            return []
+
+        for arg1, arg2 in data_conditions:
+            # is the other table already fetched?
+            table2 = self.get_table_for_column(arg2)
+            fetch_step = self.tables_fetch_step.get(table2.index)
+
+            if fetch_step is None:
+                continue
+
+            # extract distinct values
+            # remove aliases
+            arg1 = Identifier(parts=[arg1.parts[-1]])
+            arg2 = Identifier(parts=[arg2.parts[-1]])
+
+            query2 = Select(targets=[arg2], distinct=True)
+            subselect_step = SubSelectStep(query2, fetch_step.result)
+            subselect_step = self.add_plan_step(subselect_step)
+
+            conditions.append(BinaryOperation(
+                op='in',
+                args=[
+                    arg1,
+                    Parameter(subselect_step.result)
+                ]
+            ))
+
+        return conditions
+
     def process_predictor(self, item, query_in):
         if len(self.step_stack) == 0:
             raise NotImplementedError("Predictor can't be first element of join syntax")
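
Note for reviewers: below is a minimal sketch of how the new join-condition pushdown could be exercised end to end. The integration names (int1, int2), the table names and the column x are invented for illustration; parse_sql and plan_query are the library's usual entry points, and the expectation about the resulting steps follows from the patch above rather than from a verified run.

    # Hypothetical usage sketch (not part of the patch); integration, table and
    # column names are made up.
    from mindsdb_sql import parse_sql
    from mindsdb_sql.planner import plan_query

    query = parse_sql(
        'SELECT * FROM int1.tab1 AS a JOIN int2.tab2 AS b ON a.x = b.x',
        dialect='mindsdb',
    )

    plan = plan_query(
        query,
        integrations=['int1', 'int2'],
        predictor_namespace='mindsdb',
        predictor_metadata={},
    )

    # With this change, the plan is expected to contain a SubSelectStep selecting
    # DISTINCT x from the first table's fetch, and the second table's fetch query
    # should carry an `x IN (...)` condition referencing that step's result.
    for i, step in enumerate(plan.steps):
        print(i, step)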