Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize join tables from different databases #412

Merged
merged 5 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 79 additions & 3 deletions mindsdb_sql/planner/plan_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TableInfo:
sub_select: ast.ASTNode = None
predictor_info: dict = None
join_condition = None

index: int = None

class PlanJoin:

Expand Down Expand Up @@ -85,12 +85,15 @@ def __init__(self, planner):

# index to lookup tables
self.tables_idx = None
self.tables = []
self.tables_fetch_step = {}

self.step_stack = None
self.query_context = {}

self.partition = None


def plan(self, query):
self.tables_idx = {}
join_step = self.plan_join_tables(query)
Expand Down Expand Up @@ -145,7 +148,8 @@ def resolve_table(self, table):
return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select)

def get_table_for_column(self, column: Identifier):

if not isinstance(column, Identifier):
return
# to lowercase
parts = tuple(map(str.lower, column.parts[:-1]))
if parts in self.tables_idx:
Expand All @@ -160,6 +164,9 @@ def get_join_sequence(self, node, condition=None):
for alias in table_info.aliases:
self.tables_idx[alias] = table_info

table_info.index = len(self.tables)
self.tables.append(table_info)

table_info.predictor_info = self.planner.get_predictor(node)

if condition is not None:
Expand Down Expand Up @@ -375,13 +382,16 @@ def process_table(self, item, query_in):
# not use conditions
conditions = []

conditions += self.get_filters_from_join_conditions(item)

if self.query_context['use_limit']:
order_by = None
if query_in.order_by is not None:
order_by = []
# all order column be from this table
for col in query_in.order_by:
if self.get_table_for_column(col.field).table != item.table:
table_info = self.get_table_for_column(col.field)
if table_info is None or table_info.table != item.table:
order_by = False
break
col = copy.deepcopy(col)
Expand All @@ -406,6 +416,8 @@ def process_table(self, item, query_in):

# step = self.planner.get_integration_select_step(query2)
step = FetchDataframeStep(integration=item.integration, query=query2)
self.tables_fetch_step[item.index] = step

self.add_plan_step(step)
self.step_stack.append(step)

Expand Down Expand Up @@ -440,6 +452,70 @@ def _check_conditions(node, **kwargs):
query_traversal(model_table.join_condition, _check_conditions)
return columns_map

def get_filters_from_join_conditions(self, fetch_table):
    """Extract filter conditions for *fetch_table* from its join condition,
    so they can be pushed down into the query that fetches the table.

    Walks ``fetch_table.join_condition`` and collects equality comparisons
    that involve a column of ``fetch_table``:

    - ``col = <constant>``: the comparison node is returned as-is.
    - ``col = <other_table.col>``: if the other table has already been
      fetched (its fetch step is registered in ``self.tables_fetch_step``),
      a distinct sub-select over that fetched data is added to the plan and
      an ``col IN <subselect result>`` condition is returned instead.

    If the join condition contains any binary operator other than ``=``
    joined by ``and`` (e.g. ``or``, ``>``), pushing filters down is not
    safe, and an empty list is returned.

    :param fetch_table: TableInfo of the table about to be fetched
    :return: list of BinaryOperation nodes usable as WHERE conditions
    """

    # Operators other than '=' seen while traversing; used to bail out below.
    binary_ops = set()
    # Conditions that can be pushed down directly (col = constant, col IN ...).
    conditions = []
    # Pairs [our_column, other_table_column] to be turned into IN-subselects.
    data_conditions = []

    def _check_conditions(node, **kwargs):
        # Callback for query_traversal: inspect every BinaryOperation node.
        if not isinstance(node, BinaryOperation):
            return

        if node.op != '=':
            # Remember the operator; anything besides 'and' aborts pushdown.
            binary_ops.add(node.op.lower())
            return

        arg1, arg2 = node.args
        # Resolve which table each side of the equality belongs to
        # (None for constants or unresolvable identifiers).
        table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None
        table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None

        if table1 is not fetch_table:
            if table2 is not fetch_table:
                # Neither side belongs to the table being fetched: not our filter.
                return
            # set our table first
            table1, table2 = table2, table1
            arg1, arg2 = arg2, arg1

        if isinstance(arg2, Constant):
            # col = constant: push down unchanged.
            conditions.append(node)
        elif table2 is not None:
            # col = other_table.col: candidate for an IN-subselect below.
            data_conditions.append([arg1, arg2])

    query_traversal(fetch_table.join_condition, _check_conditions)

    # 'and' is the only connective that keeps the collected equalities valid
    # as independent filters; any other remaining operator disables pushdown.
    binary_ops.discard('and')
    if len(binary_ops) > 0:
        # other operations exist — skip pushdown entirely
        return []

    for arg1, arg2 in data_conditions:
        # Only usable if the other table's data was already fetched:
        # its FetchDataframeStep must be registered by index.
        table2 = self.get_table_for_column(arg2)
        fetch_step = self.tables_fetch_step.get(table2.index)

        if fetch_step is None:
            continue

        # extract distinct values
        # remove aliases — keep only the bare column name for both sides
        arg1 = Identifier(parts=[arg1.parts[-1]])
        arg2 = Identifier(parts=[arg2.parts[-1]])

        # SELECT DISTINCT <col> over the already-fetched dataframe.
        query2 = Select(targets=[arg2], distinct=True)
        subselect_step = SubSelectStep(query2, fetch_step.result)
        subselect_step = self.add_plan_step(subselect_step)

        # our_col IN (<distinct values from the other table>)
        conditions.append(BinaryOperation(
            op='in',
            args=[
                arg1,
                Parameter(subselect_step.result)
            ]
        ))

    return conditions

def process_predictor(self, item, query_in):
if len(self.step_stack) == 0:
raise NotImplementedError("Predictor can't be first element of join syntax")
Expand Down
24 changes: 12 additions & 12 deletions tests/test_planner/test_join_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def test_join_predictor_plan(self):
"""
query = parse_sql(sql)

query_step = parse_sql("select tab1.column1, pred.predicted")
query_step.from_table = Parameter(Result(2))
expected_plan = QueryPlan(
steps=[
FetchDataframeStep(integration='int',
Expand Down Expand Up @@ -75,7 +73,7 @@ def test_join_predictor_plan_aliases(self):
plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

assert plan.steps == expected_plan.steps


def test_join_predictor_plan_limit(self):

Expand Down Expand Up @@ -116,7 +114,7 @@ def test_join_predictor_plan_limit(self):
plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

assert plan.steps == expected_plan.steps


# def test_join_predictor_error_when_filtering_on_predictions(self):
# """
Expand Down Expand Up @@ -673,15 +671,16 @@ def test_complex_subselect(self):

sql = '''
select t2.x, m.id, (select a from int.tab0 where x=0) from int.tab1 t1
join int.tab2 t2 on t1.x = t2.x
join int.tab2 t2 on t1.x = t2.a
join mindsdb.pred m
where m.a=(select a from int.tab3 where x=3)
and t2.x=(select a from int.tab4 where x=4)
and t1.b=1 and t2.b=2 and t1.a = t2.a
'''

q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 ')
q_table2.where.args[0].args[1] = Parameter(Result(2))
q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 AND a IN 1')
q_table2.where.args[0].args[0].args[1] = Parameter(Result(2))
q_table2.where.args[1].args[1] = Parameter(Result(4))

subquery = parse_sql("""
select t2.x, m.id, x
Expand All @@ -708,22 +707,23 @@ def test_complex_subselect(self):
# tables
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as t1 where b=1')),
SubSelectStep(dataframe=Result(3), query=Select(targets=[Identifier('x')], distinct=True)),
FetchDataframeStep(integration='int', query=q_table2),
JoinStep(left=Result(3), right=Result(4),
JoinStep(left=Result(3), right=Result(5),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN,
condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.x')])
condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.a')])
)
),
# model
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(5),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(6),
predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': Result(1)}),
JoinStep(left=Result(5), right=Result(6),
JoinStep(left=Result(6), right=Result(7),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN)),
QueryStep(subquery, from_table=Result(7)),
QueryStep(subquery, from_table=Result(8)),
],
)
plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})
Expand Down
44 changes: 24 additions & 20 deletions tests/test_planner/test_join_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_join_tables_plan(self):
query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
)
)
Expand All @@ -35,7 +35,7 @@ def test_join_tables_plan(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
Expand All @@ -45,13 +45,13 @@ def test_join_tables_plan(self):
)

assert plan.steps == expected_plan.steps


def test_join_tables_where_plan(self):
query = parse_sql('''
SELECT tab1.column1, tab2.column1, tab2.column2
FROM int.tab1
INNER JOIN int2.tab2 ON tab1.column1 = tab2.column1
INNER JOIN int2.tab2 ON tab1.column1 > tab2.column1
WHERE ((tab1.column1 = 1)
AND (tab2.column1 = 0))
AND (tab1.column3 = tab2.column3)
Expand All @@ -71,7 +71,7 @@ def test_join_tables_where_plan(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
Expand All @@ -90,7 +90,7 @@ def test_join_tables_plan_groupby(self):
Function('sum', args=[Identifier('tab2.column2')], alias=Identifier('total'))],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
),
Expand All @@ -117,7 +117,7 @@ def test_join_tables_plan_groupby(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
Expand All @@ -126,13 +126,13 @@ def test_join_tables_plan_groupby(self):
],
)
assert plan.steps == expected_plan.steps


def test_join_tables_plan_limit_offset(self):
query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
),
limit=Constant(10),
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_join_tables_plan_limit_offset(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
Expand All @@ -171,13 +171,13 @@ def test_join_tables_plan_limit_offset(self):
)

assert plan.steps == expected_plan.steps


def test_join_tables_plan_order_by(self):
query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')],
from_table=Join(left=Identifier('int.tab1'),
right=Identifier('int2.tab2'),
condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
),
limit=Constant(10),
Expand All @@ -203,7 +203,7 @@ def test_join_tables_plan_order_by(self):
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(op='=',
condition=BinaryOperation(op='>',
args=[Identifier('tab1.column1'),
Identifier('tab2.column1')]),
join_type=JoinType.INNER_JOIN
Expand All @@ -213,7 +213,7 @@ def test_join_tables_plan_order_by(self):
)

assert plan.steps == expected_plan.steps


# This quiery should be sent to integration without raising exception
# def test_join_tables_where_ambigous_column_error(self):
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_join_tables_disambiguate_identifiers_in_condition(self):

for i in range(len(plan.steps)):
assert plan.steps[i] == expected_plan.steps[i]


def _disabled_test_join_tables_error_on_unspecified_table_in_condition(self):
# disabled: identifier can be environment of system variable
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_join_tables_plan_default_namespace(self):
def test_complex_join_tables(self):
query = parse_sql('''
select * from int1.tbl1 t1
right join int2.tbl2 t2 on t1.id=t2.id
right join int2.tbl2 t2 on t1.id>t2.id
join pred m
left join tbl3 on tbl3.id=t1.id
where t1.a=1 and t2.b=2 and 1=1
Expand All @@ -337,6 +337,9 @@ def test_complex_join_tables(self):
subquery = copy.deepcopy(query)
subquery.from_table = None

q_table3 = parse_sql('select * from tbl3 where id in 0')
q_table3.where.args[1] = Parameter(Result(5))

plan = plan_query(query, integrations=['int1', 'int2', 'proj'], default_namespace='proj',
predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}])

Expand All @@ -349,7 +352,7 @@ def test_complex_join_tables(self):
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(
op='=',
op='>',
args=[Identifier('t1.id'),
Identifier('t2.id')]),
join_type=JoinType.RIGHT_JOIN)),
Expand All @@ -359,17 +362,18 @@ def test_complex_join_tables(self):
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN)),
FetchDataframeStep(integration='proj', query=parse_sql('select * from tbl3')),
SubSelectStep(dataframe=Result(0), query=Select(targets=[Identifier('id')], distinct=True)),
FetchDataframeStep(integration='proj', query=q_table3),
JoinStep(left=Result(4),
right=Result(5),
right=Result(6),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
condition=BinaryOperation(
op='=',
args=[Identifier('tbl3.id'),
Identifier('t1.id')]),
join_type=JoinType.LEFT_JOIN)),
QueryStep(subquery, from_table=Result(6)),
QueryStep(subquery, from_table=Result(7)),
]
)

Expand Down
Loading