From 852ced3f0ca1ce1c00efa9e4546737c925ffb0a8 Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 2 May 2024 12:16:06 +0300 Subject: [PATCH] model column mapping --- mindsdb_sql/planner/plan_join.py | 43 +++++++++++++++++++++-- mindsdb_sql/planner/steps.py | 9 ++++- tests/test_planner/test_join_predictor.py | 42 ++++++++++++++++++++++ 3 files changed, 91 insertions(+), 3 deletions(-) diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index da20aa54..c2bcb452 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -19,6 +19,7 @@ class TableInfo: conditions: List = None sub_select: ast.ASTNode = None predictor_info: dict = None + join_condition = None class PlanJoin: @@ -147,7 +148,7 @@ def get_table_for_column(self, column: Identifier): if parts in self.tables_idx: return self.tables_idx[parts] - def get_join_sequence(self, node): + def get_join_sequence(self, node, condition=None): sequence = [] if isinstance(node, Identifier): # resolve identifier @@ -158,6 +159,8 @@ def get_join_sequence(self, node): table_info.predictor_info = self.planner.get_predictor(node) + if condition is not None: + table_info.join_condition = condition sequence.append(table_info) elif isinstance(node, Join): @@ -168,7 +171,7 @@ def get_join_sequence(self, node): for item in sequence2: sequence.append(item) - sequence2 = self.get_join_sequence(node.right) + sequence2 = self.get_join_sequence(node.right, condition=node.condition) if len(sequence2) != 1: raise PlanningException('Unexpected join nesting behavior') @@ -401,6 +404,37 @@ def process_table(self, item, query_in): self.planner.plan.add_step(step) self.step_stack.append(step) + def join_condition_to_columns_map(self, model_table): + + columns_map = {} + + def _check_conditions(node, **kwargs): + if not isinstance(node, BinaryOperation): + return + + arg1, arg2 = node.args + if not (isinstance(arg1, Identifier) and isinstance(arg2, Identifier)): + return + + table1 = self.get_table_for_column(arg1) + table2 = self.get_table_for_column(arg2) + + if table1 is model_table: + # model is on the left + columns_map[arg1.parts[-1]] = arg2 + elif table2 is model_table: + # model is on the right + columns_map[arg2.parts[-1]] = arg1 + else: + # not found, skip + return + + # exclude condition + node.args = [Constant(0), Constant(0)] + + query_traversal(model_table.join_condition, _check_conditions) + return columns_map + def process_predictor(self, item, query_in): if len(self.step_stack) == 0: raise NotImplementedError("Predictor can't be first element of join syntax") @@ -415,6 +449,10 @@ def process_predictor(self, item, query_in): if predict_target is not None: predict_target = predict_target.lower() + columns_map = None + if item.join_condition: + columns_map = self.join_condition_to_columns_map(item) + if item.conditions: row_dict = {} for i, el in enumerate(item.conditions): @@ -450,5 +488,6 @@ def process_predictor(self, item, query_in): predictor=item.table, params=model_params, row_dict=row_dict, + columns_map=columns_map, )) self.step_stack.append(predictor_step) diff --git a/mindsdb_sql/planner/steps.py b/mindsdb_sql/planner/steps.py index 1a1c72ea..8fd8bcf7 100644 --- a/mindsdb_sql/planner/steps.py +++ b/mindsdb_sql/planner/steps.py @@ -138,14 +138,21 @@ def __init__(self, integration, query=None, raw_query=None, *args, **kwargs): class ApplyPredictorStep(PlanStep): """Applies a mindsdb predictor on some dataframe and returns a new dataframe with predictions""" - def __init__(self, namespace, predictor, dataframe, params=None, row_dict=None, *args, **kwargs): + def __init__(self, namespace, predictor, dataframe, params: dict = None, + row_dict: dict = None, columns_map: dict = None, *args, **kwargs): super().__init__(*args, **kwargs) self.namespace = namespace self.predictor = predictor self.dataframe = dataframe self.params = params + + # columns to add to input data, struct: {column name: value} self.row_dict = row_dict + # rename columns in input data, struct: {a str: b Identifier} + # renames b to a + self.columns_map = columns_map + if isinstance(dataframe, Result): self.references.append(dataframe) diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index f09e318c..3e043b4d 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -771,4 +771,46 @@ def test_model_join_model(self): ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + assert plan.steps == expected_plan.steps + + def test_model_column_map(self): + + sql = ''' + select * from int.tab1 a + join proj.pred.1 p on a.data1 = p.data2 and p.x = a.y + ''' + + # subquery = parse_sql(""" + # select * from x + # where a.x=1 and 0=0 and p.ttt=2 and a.y=3 and 0=0 + # """) + # subquery.from_table = None + + query = parse_sql(sql) + expected_plan = QueryPlan( + steps=[ + FetchDataframeStep(integration='int', + query=parse_sql('select * from tab1 as a')), + ApplyPredictorStep( + namespace='proj', dataframe=Result(0), + predictor=Identifier('pred.1', alias=Identifier('p')), + columns_map={'data2': Identifier('a.data1'), 'x': Identifier('a.y')} + ), + JoinStep(left=Result(0), right=Result(1), + query=Join( + left=Identifier('tab1'), + right=Identifier('tab2'), + join_type=JoinType.JOIN, + condition=BinaryOperation('and', args=[ + BinaryOperation('=', args=[Constant(0), Constant(0)]), + BinaryOperation('=', args=[Constant(0), Constant(0)]) + ]) + ), + ), + ], + ) + + plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', + predictor_metadata=[{'name': 'pred', 'integration_name': 'proj', 'to_predict': ['ttt']}]) + assert plan.steps == expected_plan.steps \ No newline at end of file