Skip to content

Commit

Permalink
Merge pull request #375 from mindsdb/model-col-map
Browse files Browse the repository at this point in the history
Model column mapping
  • Loading branch information
ea-rus authored May 2, 2024
2 parents fa80417 + 852ced3 commit 7c731a1
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 3 deletions.
43 changes: 41 additions & 2 deletions mindsdb_sql/planner/plan_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class TableInfo:
conditions: List = None
sub_select: ast.ASTNode = None
predictor_info: dict = None
join_condition = None


class PlanJoin:
Expand Down Expand Up @@ -147,7 +148,7 @@ def get_table_for_column(self, column: Identifier):
if parts in self.tables_idx:
return self.tables_idx[parts]

def get_join_sequence(self, node):
def get_join_sequence(self, node, condition=None):
sequence = []
if isinstance(node, Identifier):
# resolve identifier
Expand All @@ -158,6 +159,8 @@ def get_join_sequence(self, node):

table_info.predictor_info = self.planner.get_predictor(node)

if condition is not None:
table_info.join_condition = condition
sequence.append(table_info)

elif isinstance(node, Join):
Expand All @@ -168,7 +171,7 @@ def get_join_sequence(self, node):
for item in sequence2:
sequence.append(item)

sequence2 = self.get_join_sequence(node.right)
sequence2 = self.get_join_sequence(node.right, condition=node.condition)
if len(sequence2) != 1:
raise PlanningException('Unexpected join nesting behavior')

Expand Down Expand Up @@ -401,6 +404,37 @@ def process_table(self, item, query_in):
self.planner.plan.add_step(step)
self.step_stack.append(step)

def join_condition_to_columns_map(self, model_table):
    """Derive a column mapping for a model from its join condition.

    Traverses ``model_table.join_condition`` and, for every equality
    comparison between two identifiers where one side resolves (via
    ``self.get_table_for_column``) to ``model_table``, records an entry
    ``{model column name: identifier on the other side}``. The model
    column name is the last part of the model-side identifier. The left
    operand is checked first, so if it belongs to the model it is used as
    the key.

    Every comparison that is consumed this way is neutralised in place:
    its operands are replaced with ``Constant(0), Constant(0)`` so the
    remaining join condition no longer references the mapped columns.

    :param model_table: table info of the model; its ``join_condition``
        is traversed and mutated in place.
    :return: dict mapping model column name (str) to the ``Identifier``
        found on the opposite side of the comparison.
    """
    columns_map = {}

    def _check_conditions(node, **kwargs):
        if not isinstance(node, BinaryOperation):
            return

        # Only an equality can be read as "use this column as that model
        # input". Consuming any other comparison (e.g. ``>`` or ``<>``)
        # would both invent a bogus mapping and rewrite the condition to
        # something like ``0 > 0``, silently changing the join result.
        if node.op != '=':
            return

        arg1, arg2 = node.args
        if not (isinstance(arg1, Identifier) and isinstance(arg2, Identifier)):
            return

        table1 = self.get_table_for_column(arg1)
        table2 = self.get_table_for_column(arg2)

        if table1 is model_table:
            # model is on the left
            columns_map[arg1.parts[-1]] = arg2
        elif table2 is model_table:
            # model is on the right
            columns_map[arg2.parts[-1]] = arg1
        else:
            # neither side belongs to the model: leave the condition as is
            return

        # exclude the consumed condition from the join: it becomes ``0 = 0``
        node.args = [Constant(0), Constant(0)]

    query_traversal(model_table.join_condition, _check_conditions)
    return columns_map

def process_predictor(self, item, query_in):
if len(self.step_stack) == 0:
raise NotImplementedError("Predictor can't be first element of join syntax")
Expand All @@ -415,6 +449,10 @@ def process_predictor(self, item, query_in):
if predict_target is not None:
predict_target = predict_target.lower()

columns_map = None
if item.join_condition:
columns_map = self.join_condition_to_columns_map(item)

if item.conditions:
row_dict = {}
for i, el in enumerate(item.conditions):
Expand Down Expand Up @@ -450,5 +488,6 @@ def process_predictor(self, item, query_in):
predictor=item.table,
params=model_params,
row_dict=row_dict,
columns_map=columns_map,
))
self.step_stack.append(predictor_step)
9 changes: 8 additions & 1 deletion mindsdb_sql/planner/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,21 @@ def __init__(self, integration, query=None, raw_query=None, *args, **kwargs):

class ApplyPredictorStep(PlanStep):
"""Applies a mindsdb predictor on some dataframe and returns a new dataframe with predictions"""
def __init__(self, namespace, predictor, dataframe, params=None, row_dict=None, *args, **kwargs):
def __init__(self, namespace, predictor, dataframe, params: dict = None,
row_dict: dict = None, columns_map: dict = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.namespace = namespace
self.predictor = predictor
self.dataframe = dataframe
self.params = params

# columns to add to input data, struct: {column name: value}
self.row_dict = row_dict

# rename columns in input data, struct: {a str: b Identifier}
# renames b to a
self.columns_map = columns_map

if isinstance(dataframe, Result):
self.references.append(dataframe)

Expand Down
42 changes: 42 additions & 0 deletions tests/test_planner/test_join_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,4 +771,46 @@ def test_model_join_model(self):
)
plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

assert plan.steps == expected_plan.steps

def test_model_column_map(self):
    """Identifier equalities in a model's ON clause become ``columns_map``.

    The query joins a table with a model on two equalities that each
    involve a model column. The expected plan passes them to
    ``ApplyPredictorStep`` as ``columns_map`` (model column name ->
    table-side identifier) and leaves ``0 = 0`` placeholders in the join
    condition in their place.
    """
    statement = parse_sql('''
select * from int.tab1 a
join proj.pred.1 p on a.data1 = p.data2 and p.x = a.y
''')

    # step 0: fetch the data table
    fetch_step = FetchDataframeStep(
        integration='int',
        query=parse_sql('select * from tab1 as a'),
    )

    # step 1: apply the model, renaming input columns per the ON clause
    predictor_step = ApplyPredictorStep(
        namespace='proj',
        dataframe=Result(0),
        predictor=Identifier('pred.1', alias=Identifier('p')),
        columns_map={'data2': Identifier('a.data1'), 'x': Identifier('a.y')},
    )

    # step 2: join both results; each consumed equality is now ``0 = 0``
    placeholder_left = BinaryOperation('=', args=[Constant(0), Constant(0)])
    placeholder_right = BinaryOperation('=', args=[Constant(0), Constant(0)])
    join_step = JoinStep(
        left=Result(0),
        right=Result(1),
        query=Join(
            left=Identifier('tab1'),
            right=Identifier('tab2'),
            join_type=JoinType.JOIN,
            condition=BinaryOperation('and', args=[placeholder_left, placeholder_right]),
        ),
    )

    expected = QueryPlan(steps=[fetch_step, predictor_step, join_step])

    model_metadata = [{'name': 'pred', 'integration_name': 'proj', 'to_predict': ['ttt']}]
    actual = plan_query(
        statement,
        integrations=['int'],
        predictor_namespace='mindsdb',
        predictor_metadata=model_metadata,
    )

    assert actual.steps == expected.steps

0 comments on commit 7c731a1

Please sign in to comment.