Skip to content

Commit

Permalink
Merge branch 'optimize-join' into cte-support
Browse files Browse the repository at this point in the history
# Conflicts:
#	mindsdb_sql/planner/plan_join.py
  • Loading branch information
ea-rus committed Nov 2, 2024
2 parents c5cb9df + ef30b12 commit e324f45
Showing 1 changed file with 79 additions and 3 deletions.
82 changes: 79 additions & 3 deletions mindsdb_sql/planner/plan_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TableInfo:
sub_select: ast.ASTNode = None
predictor_info: dict = None
join_condition = None

index: int = None

class PlanJoin:

Expand Down Expand Up @@ -85,12 +85,15 @@ def __init__(self, planner):

# index to lookup tables
self.tables_idx = None
self.tables = []
self.tables_fetch_step = {}

self.step_stack = None
self.query_context = {}

self.partition = None


def plan(self, query):
self.tables_idx = {}
join_step = self.plan_join_tables(query)
Expand Down Expand Up @@ -146,7 +149,8 @@ def resolve_table(self, table):
return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select)

def get_table_for_column(self, column: Identifier):

if not isinstance(column, Identifier):
return
# to lowercase
parts = tuple(map(str.lower, column.parts[:-1]))
if parts in self.tables_idx:
Expand All @@ -161,6 +165,9 @@ def get_join_sequence(self, node, condition=None):
for alias in table_info.aliases:
self.tables_idx[alias] = table_info

table_info.index = len(self.tables)
self.tables.append(table_info)

table_info.predictor_info = self.planner.get_predictor(node)

if condition is not None:
Expand Down Expand Up @@ -378,13 +385,16 @@ def process_table(self, item, query_in):
# not use conditions
conditions = []

conditions += self.get_filters_from_join_conditions(item)

if self.query_context['use_limit']:
order_by = None
if query_in.order_by is not None:
order_by = []
# all order column be from this table
for col in query_in.order_by:
if self.get_table_for_column(col.field).table != item.table:
table_info = self.get_table_for_column(col.field)
if table_info is None or table_info.table != item.table:
order_by = False
break
col = copy.deepcopy(col)
Expand All @@ -408,6 +418,8 @@ def process_table(self, item, query_in):
query2.where = cond

step = self.planner.get_integration_select_step(query2)
self.tables_fetch_step[item.index] = step

self.add_plan_step(step)
self.step_stack.append(step)

Expand Down Expand Up @@ -442,6 +454,70 @@ def _check_conditions(node, **kwargs):
query_traversal(model_table.join_condition, _check_conditions)
return columns_map

def get_filters_from_join_conditions(self, fetch_table):

binary_ops = set()
conditions = []
data_conditions = []

def _check_conditions(node, **kwargs):
if not isinstance(node, BinaryOperation):
return

if node.op != '=':
binary_ops.add(node.op.lower())
return

arg1, arg2 = node.args
table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None
table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None

if table1 is not fetch_table:
if table2 is not fetch_table:
return
# set our table first
table1, table2 = table2, table1
arg1, arg2 = arg2, arg1

if isinstance(arg2, Constant):
conditions.append(node)
elif table2 is not None:
data_conditions.append([arg1, arg2])

query_traversal(fetch_table.join_condition, _check_conditions)

binary_ops.discard('and')
if len(binary_ops) > 0:
# other operations exists, skip
return []

for arg1, arg2 in data_conditions:
# is fetched?
table2 = self.get_table_for_column(arg2)
fetch_step = self.tables_fetch_step.get(table2.index)

if fetch_step is None:
continue

# extract distinct values
# remove aliases
arg1 = Identifier(parts=[arg1.parts[-1]])
arg2 = Identifier(parts=[arg2.parts[-1]])

query2 = Select(targets=[arg2], distinct=True)
subselect_step = SubSelectStep(query2, fetch_step.result)
subselect_step = self.add_plan_step(subselect_step)

conditions.append(BinaryOperation(
op='in',
args=[
arg1,
Parameter(subselect_step.result)
]
))

return conditions

def process_predictor(self, item, query_in):
if len(self.step_stack) == 0:
raise NotImplementedError("Predictor can't be first element of join syntax")
Expand Down

0 comments on commit e324f45

Please sign in to comment.