Merge pull request #412 from mindsdb/optimize-join
Optimize join tables from different databases
ea-rus authored Nov 6, 2024
2 parents c50e06c + 0f5dbb1 commit 5525c2c
Showing 3 changed files with 115 additions and 35 deletions.
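For context on the change ("Optimize join tables from different databases"): the planner now derives a filter from the join condition and pushes it into the fetch of the second table, so a join across two integrations no longer pulls the second table in full. Below is a minimal sketch of planning such a cross-database join, mirroring the parse_sql/plan_query entry points used by the tests in this PR; the integration and table names and the keyword arguments are illustrative, not part of the diff.

    # Minimal sketch (assumed usage, mirroring the test suite in this PR).
    from mindsdb_sql import parse_sql
    from mindsdb_sql.planner import plan_query

    query = parse_sql('''
        SELECT t1.a, t2.b
        FROM int1.tab1 t1
        JOIN int2.tab2 t2 ON t1.id = t2.id
        WHERE t1.a = 1
    ''')

    plan = plan_query(query, integrations=['int1', 'int2'],
                      predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

    # After this change the fetch for int2.tab2 carries an extra condition of the
    # form "id IN (<distinct t1.id values from the already-fetched int1.tab1>)"
    # instead of reading tab2 in full; the loop below prints the generated steps.
    for step in plan.steps:
        print(step)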
82 changes: 79 additions & 3 deletions mindsdb_sql/planner/plan_join.py
@@ -21,7 +21,7 @@ class TableInfo:
sub_select: ast.ASTNode = None
predictor_info: dict = None
join_condition = None

index: int = None

class PlanJoin:

@@ -85,12 +85,15 @@ def __init__(self, planner):

# index to lookup tables
self.tables_idx = None
self.tables = []
self.tables_fetch_step = {}

self.step_stack = None
self.query_context = {}

self.partition = None


def plan(self, query):
self.tables_idx = {}
join_step = self.plan_join_tables(query)
@@ -145,7 +148,8 @@ def resolve_table(self, table):
return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select)

def get_table_for_column(self, column: Identifier):

if not isinstance(column, Identifier):
return
# to lowercase
parts = tuple(map(str.lower, column.parts[:-1]))
if parts in self.tables_idx:
@@ -160,6 +164,9 @@ def get_join_sequence(self, node, condition=None):
for alias in table_info.aliases:
self.tables_idx[alias] = table_info

table_info.index = len(self.tables)
self.tables.append(table_info)

table_info.predictor_info = self.planner.get_predictor(node)

if condition is not None:
@@ -375,13 +382,16 @@ def process_table(self, item, query_in):
# not use conditions
conditions = []

conditions += self.get_filters_from_join_conditions(item)

if self.query_context['use_limit']:
order_by = None
if query_in.order_by is not None:
order_by = []
# all order columns must be from this table
for col in query_in.order_by:
if self.get_table_for_column(col.field).table != item.table:
table_info = self.get_table_for_column(col.field)
if table_info is None or table_info.table != item.table:
order_by = False
break
col = copy.deepcopy(col)
@@ -406,6 +416,8 @@

# step = self.planner.get_integration_select_step(query2)
step = FetchDataframeStep(integration=item.integration, query=query2)
self.tables_fetch_step[item.index] = step

self.add_plan_step(step)
self.step_stack.append(step)

@@ -440,6 +452,70 @@ def _check_conditions(node, **kwargs):
query_traversal(model_table.join_condition, _check_conditions)
return columns_map

def get_filters_from_join_conditions(self, fetch_table):

binary_ops = set()
conditions = []
data_conditions = []

def _check_conditions(node, **kwargs):
if not isinstance(node, BinaryOperation):
return

if node.op != '=':
binary_ops.add(node.op.lower())
return

arg1, arg2 = node.args
table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None
table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None

if table1 is not fetch_table:
if table2 is not fetch_table:
return
# set our table first
table1, table2 = table2, table1
arg1, arg2 = arg2, arg1

if isinstance(arg2, Constant):
conditions.append(node)
elif table2 is not None:
data_conditions.append([arg1, arg2])

query_traversal(fetch_table.join_condition, _check_conditions)

binary_ops.discard('and')
if len(binary_ops) > 0:
# other operations exist, skip
return []

for arg1, arg2 in data_conditions:
# is the other table already fetched?
table2 = self.get_table_for_column(arg2)
fetch_step = self.tables_fetch_step.get(table2.index)

if fetch_step is None:
continue

# extract distinct values
# remove aliases
arg1 = Identifier(parts=[arg1.parts[-1]])
arg2 = Identifier(parts=[arg2.parts[-1]])

query2 = Select(targets=[arg2], distinct=True)
subselect_step = SubSelectStep(query2, fetch_step.result)
subselect_step = self.add_plan_step(subselect_step)

conditions.append(BinaryOperation(
op='in',
args=[
arg1,
Parameter(subselect_step.result)
]
))

return conditions

def process_predictor(self, item, query_in):
if len(self.step_stack) == 0:
raise NotImplementedError("Predictor can't be first element of join syntax")
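The core of the optimization is get_filters_from_join_conditions above: for an equality join where one side already has a fetch step, the planner adds a SELECT DISTINCT sub-select over that step's result and appends an IN condition for the other side, so the second integration only returns matching rows (the test changes below check exactly this extra SubSelectStep and IN parameter). Below is a standalone sketch of the runtime effect, independent of the planner classes; the helper and variable names are hypothetical, and in the real plan the values stay behind a Parameter bound to a SubSelectStep rather than being inlined.

    # Hypothetical standalone sketch of the pushdown idea, not part of mindsdb_sql.
    def build_pushdown_filter(fetched_rows, fetched_key, other_column):
        """Collect distinct join-key values from the already-fetched rows and
        render an IN filter for the other table's column."""
        values = sorted({row[fetched_key] for row in fetched_rows if row[fetched_key] is not None})
        if not values:
            return None  # nothing to push down
        return f"{other_column} IN ({', '.join(repr(v) for v in values)})"

    tab1_rows = [{'id': 1, 'a': 10}, {'id': 2, 'a': 20}, {'id': 1, 'a': 30}]
    print(build_pushdown_filter(tab1_rows, 'id', 'column1'))
    # -> column1 IN (1, 2)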
24 changes: 12 additions & 12 deletions tests/test_planner/test_join_predictor.py
@@ -22,8 +22,6 @@ def test_join_predictor_plan(self):
"""
query = parse_sql(sql)

query_step = parse_sql("select tab1.column1, pred.predicted")
query_step.from_table = Parameter(Result(2))
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
@@ -75,7 +73,7 @@ def test_join_predictor_plan_aliases(self):
plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

assert plan.steps == expected_plan.steps


def test_join_predictor_plan_limit(self):

@@ -116,7 +114,7 @@ def test_join_predictor_plan_limit(self):
plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

assert plan.steps == expected_plan.steps


# def test_join_predictor_error_when_filtering_on_predictions(self):
# """
@@ -673,15 +671,16 @@ def test_complex_subselect(self):

sql = '''
select t2.x, m.id, (select a from int.tab0 where x=0) from int.tab1 t1
join int.tab2 t2 on t1.x = t2.x
join int.tab2 t2 on t1.x = t2.a
join mindsdb.pred m
where m.a=(select a from int.tab3 where x=3)
and t2.x=(select a from int.tab4 where x=4)
and t1.b=1 and t2.b=2 and t1.a = t2.a
'''

q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 ')
q_table2.where.args[0].args[1] = Parameter(Result(2))
q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 AND a IN 1')
q_table2.where.args[0].args[0].args[1] = Parameter(Result(2))
q_table2.where.args[1].args[1] = Parameter(Result(4))

subquery = parse_sql("""
select t2.x, m.id, x
@@ -708,22 +707,23 @@
# tables
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as t1 where b=1')),
SubSelectStep(dataframe=Result(3), query=Select(targets=[Identifier('x')], distinct=True)),
FetchDataframeStep(integration='int', query=q_table2),
JoinStep(left=Result(3), right=Result(4),
JoinStep(left=Result(3), right=Result(5),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN,
condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.x')])
condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.a')])
)
),
# model
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(5),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(6),
predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': Result(1)}),
JoinStep(left=Result(5), right=Result(6),
JoinStep(left=Result(6), right=Result(7),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN)),
QueryStep(subquery, from_table=Result(7)),
QueryStep(subquery, from_table=Result(8)),
],
)
plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})
44 changes: 24 additions & 20 deletions tests/test_planner/test_join_tables.py
@@ -16,7 +16,7 @@ def test_join_tables_plan(self):
query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
)
)
@@ -35,7 +35,7 @@ def test_join_tables_plan(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
@@ -45,13 +45,13 @@
)

assert plan.steps == expected_plan.steps


def test_join_tables_where_plan(self):
query = parse_sql('''
SELECT tab1.column1, tab2.column1, tab2.column2
FROM int.tab1
INNER JOIN int2.tab2 ON tab1.column1 = tab2.column1
INNER JOIN int2.tab2 ON tab1.column1 > tab2.column1
WHERE ((tab1.column1 = 1)
AND (tab2.column1 = 0))
AND (tab1.column3 = tab2.column3)
@@ -71,7 +71,7 @@ def test_join_tables_where_plan(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
@@ -90,7 +90,7 @@ def test_join_tables_plan_groupby(self):
Function('sum', args=[Identifier('tab2.column2')], alias=Identifier('total'))],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
),
@@ -117,7 +117,7 @@ def test_join_tables_plan_groupby(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
@@ -126,13 +126,13 @@
],
)
assert plan.steps == expected_plan.steps


def test_join_tables_plan_limit_offset(self):
query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
),
limit=Constant(10),
@@ -161,7 +161,7 @@ def test_join_tables_plan_limit_offset(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
@@ -171,13 +171,13 @@
)

assert plan.steps == expected_plan.steps


def test_join_tables_plan_order_by(self):
query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
),
limit=Constant(10),
@@ -203,7 +203,7 @@ def test_join_tables_plan_order_by(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
@@ -213,7 +213,7 @@ def test_join_tables_plan_order_by(self):
)

assert plan.steps == expected_plan.steps


# This query should be sent to the integration without raising an exception
# def test_join_tables_where_ambigous_column_error(self):
@@ -278,7 +278,7 @@ def test_join_tables_disambiguate_identifiers_in_condition(self):

for i in range(len(plan.steps)):
assert plan.steps[i] == expected_plan.steps[i]


def _disabled_test_join_tables_error_on_unspecified_table_in_condition(self):
# disabled: identifier can be environment of system variable
@@ -328,7 +328,7 @@ def test_join_tables_plan_default_namespace(self):
def test_complex_join_tables(self):
query = parse_sql('''
select * from int1.tbl1 t1
right join int2.tbl2 t2 on t1.id=t2.id
right join int2.tbl2 t2 on t1.id>t2.id
join pred m
left join tbl3 on tbl3.id=t1.id
where t1.a=1 and t2.b=2 and 1=1
@@ -337,6 +337,9 @@ def test_complex_join_tables(self):
subquery = copy.deepcopy(query)
subquery.from_table = None

q_table3 = parse_sql('select * from tbl3 where id in 0')
q_table3.where.args[1] = Parameter(Result(5))

plan = plan_query(query, integrations=['int1', 'int2', 'proj'], default_namespace='proj',
predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}])

@@ -349,7 +352,7 @@ def test_complex_join_tables(self):
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(
op='=',
op='>',
args=[Identifier('t1.id'),
Identifier('t2.id')]),
join_type=JoinType.RIGHT_JOIN)),
@@ -359,17 +362,18 @@ def test_complex_join_tables(self):
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN)),
FetchDataframeStep(integration='proj', query=parse_sql('select * from tbl3')),
SubSelectStep(dataframe=Result(0), query=Select(targets=[Identifier('id')], distinct=True)),
FetchDataframeStep(integration='proj', query=q_table3),
JoinStep(left=Result(4),
right=Result(5),
right=Result(6),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(
op='=',
args=[Identifier('tbl3.id'),
Identifier('t1.id')]),
join_type=JoinType.LEFT_JOIN)),
QueryStep(subquery, from_table=Result(6)),
QueryStep(subquery, from_table=Result(7)),
]
)

