Skip to content

Fix using predictor as a subselect #297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mindsdb_sql/parser/ast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ def __eq__(self, other):
return self.to_tree() == other.to_tree() and to_single_line(str(self)) == to_single_line(str(other))
else:
return False

def __repr__(self):
sql = self.to_string().replace('\n', ' ')
if len(sql) > 500:
sql = sql[:500] + '...'
return f'{self.__class__.__name__}({sql})'
6 changes: 4 additions & 2 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def find_selects(node, **kwargs):

node.parentheses = False
last_step = self.plan_select(node)

node2 = Parameter(last_step.result)

return node2
Expand Down Expand Up @@ -273,10 +274,11 @@ def plan_select_identifier(self, query):
query.targets = utils.query_traversal(query.targets, find_selects)
utils.query_traversal(query.where, find_selects)

if len(query_info['predictors']) >= 1:
if query.from_table in query_info['predictors']:
# select from predictor
return self.plan_select_from_predictor(query)
else:
# fallback to integration
return self.plan_integration_select(query)

def plan_nested_select(self, select):
Expand Down Expand Up @@ -386,7 +388,7 @@ def plan_select_from_predictor(self, select):
)
)
project_step = self.plan_project(select, predictor_step.result)
return predictor_step, project_step
return project_step

def plan_predictor(self, query, table, predictor_namespace, predictor):
int_select = copy.deepcopy(query)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_planner/test_prepared_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def execute(self, step):
if name.isdigit():
name = step.predictor.parts[-2]

if name in ('pred', 'tp3', 'pr'):
if name in ('pred', 'tp3', 'pr', 'embedding_model'):
cols = [
{'name': 'id', 'type': 'int'},
{'name': 'value', 'type': 'str'},
Expand Down
97 changes: 97 additions & 0 deletions tests/test_planner/test_select_from_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,100 @@ def test_select_from_predictor_plan_other_ml(self):
assert plan.steps == expected_plan.steps



Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding tests : )

class TestNestedSelect:

def test_using_predictor_in_subselect(self):
"""
Use predictor in subselect when selecting from integration
"""
sql = """
SELECT *
FROM chromadb.test_tabl
WHERE
search_vector = (
SELECT emebddings
FROM mindsdb.embedding_model
WHERE
content = 'some text'
)
"""
ast_tree = parse_sql(sql)
plan = plan_query(
ast_tree,
integrations=['chromadb'],
predictor_metadata=[
{'name': 'embedding_model', 'integration_name': 'mindsdb'}
]
)

expected_plan = [
ApplyPredictorRowStep(
step_num=0,
namespace='mindsdb',
predictor=Identifier(parts=['embedding_model']),
row_dict={'content': 'some text'}
),
ProjectStep(
step_num=1,
dataframe=Result(0),
columns=[Identifier(parts=['emebddings'])]
),
FetchDataframeStep(
step_num=2,
integration='chromadb',
query=Select(
targets=[Star()],
from_table=Identifier(parts=['test_tabl']),
where=BinaryOperation(
op='=',
args=[
Identifier(parts=['test_tabl', 'search_vector']),
Parameter(Result(1))
]
)
),
),
]

assert plan.steps == expected_plan

def test_using_integration_in_subselect(self):
"""
Use integration in subselect when selecting from predictor
"""
sql = """

SELECT *
FROM mindsdb.embedding_model
WHERE
content = (
SELECT content
FROM chromadb.test_tabl
LIMIT 1
)
"""
ast_tree = parse_sql(sql)
plan = plan_query(
ast_tree,
integrations=['chromadb'],
predictor_metadata=[
{'name': 'embedding_model', 'integration_name': 'mindsdb'}
]
)

expected_plan = [
FetchDataframeStep(
step_num=0,
integration='chromadb',
query=parse_sql('SELECT test_tabl.content AS content FROM test_tabl LIMIT 1')
),
ApplyPredictorRowStep(
step_num=1,
namespace='mindsdb',
predictor=Identifier(parts=['embedding_model']),
row_dict={'content': Parameter(Result(0))}
)
]

assert plan.steps == expected_plan