Skip to content

Commit

Permalink
Merge pull request #297 from mindsdb/fix/predictor-in-subselect
Browse files Browse the repository at this point in the history
Fix using predictor as a subselect
  • Loading branch information
yuhuishi-convect authored Sep 11, 2023
2 parents 0c83ba3 + d3026d5 commit fb7688f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 2 deletions.
6 changes: 6 additions & 0 deletions mindsdb_sql/parser/ast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ def __eq__(self, other):
return self.to_tree() == other.to_tree() and to_single_line(str(self)) == to_single_line(str(other))
else:
return False

def __repr__(self):
    """Return a compact, single-line debug representation of this node.

    The node is rendered with ``to_string()``, every newline is turned
    into a space, and text longer than 500 characters is cut at 500 and
    marked with a trailing ``...``. The result is wrapped as
    ``ClassName(<sql>)``.
    """
    limit = 500
    # Flatten the rendered SQL so the repr always fits on one line.
    flat = ' '.join(self.to_string().split('\n'))
    shown = flat if len(flat) <= limit else flat[:limit] + '...'
    return '{}({})'.format(self.__class__.__name__, shown)
7 changes: 6 additions & 1 deletion mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def find_selects(node, **kwargs):

node.parentheses = False
last_step = self.plan_select(node)

node2 = Parameter(last_step.result)

return node2
Expand Down Expand Up @@ -273,10 +274,14 @@ def plan_select_identifier(self, query):
query.targets = utils.query_traversal(query.targets, find_selects)
utils.query_traversal(query.where, find_selects)

# get info of updated query
query_info = self.get_query_info(query)

if len(query_info['predictors']) >= 1:
# select from predictor
return self.plan_select_from_predictor(query)
else:
# fallback to integration
return self.plan_integration_select(query)

def plan_nested_select(self, select):
Expand Down Expand Up @@ -386,7 +391,7 @@ def plan_select_from_predictor(self, select):
)
)
project_step = self.plan_project(select, predictor_step.result)
return predictor_step, project_step
return project_step

def plan_predictor(self, query, table, predictor_namespace, predictor):
int_select = copy.deepcopy(query)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_planner/test_prepared_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def execute(self, step):
if name.isdigit():
name = step.predictor.parts[-2]

if name in ('pred', 'tp3', 'pr'):
if name in ('pred', 'tp3', 'pr', 'embedding_model'):
cols = [
{'name': 'id', 'type': 'int'},
{'name': 'value', 'type': 'str'},
Expand Down
97 changes: 97 additions & 0 deletions tests/test_planner/test_select_from_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,100 @@ def test_select_from_predictor_plan_other_ml(self):
assert plan.steps == expected_plan.steps



class TestNestedSelect:
# Planner tests for queries whose WHERE clause contains a nested SELECT that
# targets a different kind of source (predictor vs. integration) than the
# outer query. Each test builds the plan with `plan_query` and compares
# `plan.steps` against a hand-written list of expected steps.

def test_using_predictor_in_subselect(self):
"""
Use predictor in subselect when selecting from integration
"""
# Outer query reads from the `chromadb` integration; the subselect reads
# from `mindsdb.embedding_model`, which is registered as a predictor below.
# NOTE(review): `emebddings` / `test_tabl` look like misspellings, but the
# same spelling is used in the expected plan, so the comparison is consistent.
sql = """
SELECT *
FROM chromadb.test_tabl
WHERE
search_vector = (
SELECT emebddings
FROM mindsdb.embedding_model
WHERE
content = 'some text'
)
"""
ast_tree = parse_sql(sql)
plan = plan_query(
ast_tree,
integrations=['chromadb'],
predictor_metadata=[
{'name': 'embedding_model', 'integration_name': 'mindsdb'}
]
)

expected_plan = [
# Step 0: the subselect's WHERE condition becomes the predictor's input row.
ApplyPredictorRowStep(
step_num=0,
namespace='mindsdb',
predictor=Identifier(parts=['embedding_model']),
row_dict={'content': 'some text'}
),
# Step 1: project the subselect's target column out of the predictor result.
ProjectStep(
step_num=1,
dataframe=Result(0),
columns=[Identifier(parts=['emebddings'])]
),
# Step 2: the outer integration query, with the subselect replaced by a
# Parameter referring to step 1's result; the filtered column is expected
# to be qualified with the table name.
FetchDataframeStep(
step_num=2,
integration='chromadb',
query=Select(
targets=[Star()],
from_table=Identifier(parts=['test_tabl']),
where=BinaryOperation(
op='=',
args=[
Identifier(parts=['test_tabl', 'search_vector']),
Parameter(Result(1))
]
)
),
),
]

assert plan.steps == expected_plan

def test_using_integration_in_subselect(self):
"""
Use integration in subselect when selecting from predictor
"""
# Mirror image of the previous test: the outer query reads from the
# predictor and the subselect reads from the `chromadb` integration.
sql = """
SELECT *
FROM mindsdb.embedding_model
WHERE
content = (
SELECT content
FROM chromadb.test_tabl
LIMIT 1
)
"""
ast_tree = parse_sql(sql)
plan = plan_query(
ast_tree,
integrations=['chromadb'],
predictor_metadata=[
{'name': 'embedding_model', 'integration_name': 'mindsdb'}
]
)

# The expected plan has exactly two steps.
expected_plan = [
# Step 0: the subselect is fetched from the integration; the expected query
# has the integration prefix removed from the table and the target column
# qualified and aliased.
FetchDataframeStep(
step_num=0,
integration='chromadb',
query=parse_sql('SELECT test_tabl.content AS content FROM test_tabl LIMIT 1')
),
# Step 1: the predictor is applied to a row whose `content` value is a
# Parameter referring to step 0's result.
ApplyPredictorRowStep(
step_num=1,
namespace='mindsdb',
predictor=Identifier(parts=['embedding_model']),
row_dict={'content': Parameter(Result(0))}
)
]

assert plan.steps == expected_plan

0 comments on commit fb7688f

Please sign in to comment.