diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py
index 6b46a56a..8ba08910 100644
--- a/mindsdb_sql/planner/query_planner.py
+++ b/mindsdb_sql/planner/query_planner.py
@@ -32,7 +32,9 @@ def __init__(self,
                  integrations: list = None,
                  predictor_namespace=None,
                  predictor_metadata: list = None,
-                 default_namespace: str = None):
+                 default_namespace: str = None,
+                 additional_metadata: list = None,
+                 ):
         self.query = query
         self.plan = QueryPlan()
 
@@ -90,8 +92,46 @@ def __init__(self,
         self.projects = list(_projects)
         self.databases = list(self.integrations.keys()) + self.projects
 
+        # additional metadata -- knowledge base
+        self.additional_metadata = {}
+        additional_metadata = additional_metadata or []
+        for metadata in additional_metadata:
+            if 'integration_name' not in metadata:
+                metadata['integration_name'] = self.predictor_namespace
+            idx = f'{metadata["integration_name"]}.{metadata["name"]}'.lower()
+
+            self._validate_knowledge_base_meta(metadata)
+            self.additional_metadata[idx] = metadata
+
         self.statement = None
 
+    def _validate_knowledge_base_meta(self, metadata):
+        """
+        Verify that an entry of knowledge base metadata is valid
+        """
+        TYPE_FIELD = "type"
+        if TYPE_FIELD not in metadata:
+            return
+        elif metadata[TYPE_FIELD] != "knowledge_base":
+            return
+        MODEL_FIELD = "model"
+        STORAGE_FIELD = "storage"
+        if MODEL_FIELD not in metadata:
+            raise PlanningException(f"Knowledge base metadata must contain a {MODEL_FIELD} field")
+        else:
+            # we require a fully qualified name for the model,
+            # e.g., integration_name.model_name
+            model_name = metadata[MODEL_FIELD]
+            if len(model_name.split(".")) != 2:
+                raise PlanningException("Knowledge base model name must be in the format integration_name.model_name")
+
+        if STORAGE_FIELD not in metadata:
+            raise PlanningException(f"Knowledge base metadata must contain a {STORAGE_FIELD} field")
+        else:
+            storage_name = metadata[STORAGE_FIELD]
+            if len(storage_name.split(".")) != 2:
+                raise PlanningException("Knowledge base storage name must be in the format integration_name.table_name")
+
     def is_predictor(self, identifier):
         return self.get_predictor(identifier) is not None
 
@@ -124,6 +164,30 @@ def get_predictor(self, identifier):
             info['name'] = name
         return info
 
+    def get_knowledge_base(self, identifier):
+        name_parts = list(identifier.parts)
+        name = name_parts[-1]
+        namespace = None
+        if len(name_parts) > 1:
+            namespace = name_parts[-2]
+        else:
+            if self.default_namespace is not None:
+                namespace = self.default_namespace
+
+        idx_ar = [name]
+        if namespace is not None:
+            idx_ar.insert(0, namespace)
+
+        idx = '.'.join(idx_ar).lower()
+        info = self.additional_metadata.get(idx)
+        if info is not None and info.get("type") == "knowledge_base":
+            return info
+        else:
+            return
+
+    def is_knowledge_base(self, identifier):
+        return self.get_knowledge_base(identifier) is not None
+
     def prepare_integration_select(self, database, query):
         # replacement for 'utils.recursively_disambiguate_*' functions from utils
         # main purpose: make tests working (don't change planner outputs)
@@ -1155,6 +1219,10 @@ def plan_create_table(self, query):
     def plan_insert(self, query):
         table = query.table
 
+        if self.is_knowledge_base(table):
+            # knowledge base table
+            return self.plan_insert_knowledge_base(query)
+
         if query.from_select is not None:
             integration_name = query.table.parts[0]
 
@@ -1210,7 +1278,12 @@ def plan_select(self, query, integration=None):
         from_table = query.from_table
 
         if isinstance(from_table, Identifier):
-            return self.plan_select_identifier(query)
+            # decide whether from_table is a knowledge base table or a regular integration table
+            if self.is_knowledge_base(from_table):
+                # knowledge base table
+                return self.plan_select_knowledge_base(query)
+            else:
+                return self.plan_select_identifier(query)
         elif isinstance(from_table, Select):
             return self.plan_nested_select(query)
         elif isinstance(from_table, Join):
@@ -1260,6 +1333,121 @@ def plan_union(self, query):
         return self.plan.add_step(UnionStep(left=query1.result, right=query2.result,
                                             unique=query.unique))
 
+
+    def plan_select_knowledge_base(self, query):
+        SEARCH_QUERY = "search_query"  # TODO: need to make it a constant
+        MODEL_FIELD = "model"  # TODO: need to make it a constant
+        STORAGE_FIELD = "storage"  # TODO: need to make it a constant
+
+        knowledge_base_metadata = self.get_knowledge_base(query.from_table)
+        vector_database_table = knowledge_base_metadata[STORAGE_FIELD]
+        model_name = knowledge_base_metadata[MODEL_FIELD]
+
+        CONTENT_FIELD = knowledge_base_metadata.get("content_field") or "content"
+        EMBEDDINGS_FIELD = knowledge_base_metadata.get("embeddings_field") or "embeddings"
+        SEARCH_VECTOR_FIELD = knowledge_base_metadata.get("search_vector_field") or "search_vector"
+
+        is_search_query_present = False
+
+        def find_search_query(node, **kwargs):
+            nonlocal is_search_query_present
+            if isinstance(node, Identifier) and node.parts[-1] == SEARCH_QUERY:
+                is_search_query_present = True
+
+        # decide whether the predictor is needed in the query
+        # by detecting whether the where clause involves the SEARCH_QUERY field
+        # if yes, we need to add an additional step to the plan
+        # to apply the predictor to the search query
+        utils.query_traversal(
+            query.where,
+            callback=find_search_query
+        )
+
+        if not is_search_query_present:
+            # dispatch to the underlying storage table
+            query.from_table = Identifier(vector_database_table)
+            return self.plan_select(query)
+        else:
+            # rewrite the where clause
+            # search_query = 'some text'
+            # ->
+            # search_vector = (select embeddings from model_name where content = 'some text')
+            def rewrite_search_query_clause(node, **kwargs):
+                if isinstance(node, BinaryOperation):
+                    if node.args[0] == Identifier(SEARCH_QUERY):
+                        node.args[0] = Identifier(SEARCH_VECTOR_FIELD)
+                        node.args[1] = Select(
+                            targets=[Identifier(EMBEDDINGS_FIELD)],
+                            from_table=Identifier(model_name),
+                            where=BinaryOperation(
+                                op="=",
+                                args=[
+                                    Identifier(CONTENT_FIELD),
+                                    node.args[1]
+                                ]
+                            )
+                        )
+
+            utils.query_traversal(
+                query.where,
+                callback=rewrite_search_query_clause
+            )
+
+            # dispatch to the underlying storage table
+            query.from_table = Identifier(vector_database_table)
+            return self.plan_select(query)
+
+    def plan_insert_knowledge_base(self, query: Insert):
+        metadata = self.get_knowledge_base(query.table)
+        STORAGE_FIELD = "storage"  # TODO: need to make it a constant
+        MODEL_FIELD = "model"  # TODO: need to make it a constant
+        EMBEDDINGS_FIELD = metadata.get("embeddings_field") or "embeddings"
+
+        vector_database_table = metadata[STORAGE_FIELD]
+        model_name = metadata[MODEL_FIELD]
+
+        query.table = Identifier(vector_database_table)
+
+        if query.from_select is not None:
+            # detect whether the embeddings field is present in the columns list
+            # if so, we do not need to apply the predictor
+            # if not, we need to join the select with the model table
+            is_embeddings_field_present = False
+
+            def find_embeddings_field(node, **kwargs):
+                nonlocal is_embeddings_field_present
+                if isinstance(node, Identifier) and node.parts[-1] == EMBEDDINGS_FIELD:
+                    is_embeddings_field_present = True
+
+            utils.query_traversal(
+                query.columns,
+                callback=find_embeddings_field
+            )
+
+            if is_embeddings_field_present:
+                return self.plan_insert(query)
+
+            # rewrite the select statement
+            # to join with the model table
+            select: Select = query.from_select
+            select.targets.append(Identifier(EMBEDDINGS_FIELD))
+            select.from_table = Select(
+                targets=copy.deepcopy(select.targets),
+                from_table=Join(
+                    left=select.from_table,
+                    right=Identifier(model_name),
+                    join_type="JOIN"
+                )
+            )
+
+            # append the embeddings field to the columns list
+            if query.columns:
+                query.columns.append(Identifier(EMBEDDINGS_FIELD))
+
+            return self.plan_insert(query)
+        else:
+            raise NotImplementedError("Insert without a select is not implemented for knowledge bases")
+
     # method for compatibility
     def from_query(self, query=None):
         if query is None:
diff --git a/tests/test_planner/test_knowledge_base.py b/tests/test_planner/test_knowledge_base.py
new file mode 100644
index 00000000..806cf6ef
--- /dev/null
+++ b/tests/test_planner/test_knowledge_base.py
@@ -0,0 +1,213 @@
+# test planning for knowledge base related queries
+
+import pytest
+from mindsdb_sql.parser.ast import *
+from mindsdb_sql.planner import plan_query
+from mindsdb_sql import parse_sql
+from mindsdb_sql.planner.step_result import Result
+from mindsdb_sql.planner.steps import *
+from functools import partial
+
+
+@pytest.fixture
+def planner_context():
+    integrations = [
+        {
+            "name": "my_chromadb",
+            "type": "data",
+        },
+        {
+            "name": "my_database",
+            "type": "data",
+        },
+    ]
+
+    predictors = [
+        {
+            "name": "my_model",
+            "integration_name": "mindsdb",
+        },
+    ]
+
+    additional_context = [
+        {
+            "name": "my_kb",
+            "type": "knowledge_base",
+            "model": "mindsdb.my_model",
+            "storage": "my_chromadb.my_table",
+            "search_vector_field": "search_vector",
+            "embeddings_field": "embeddings",
+            "content_field": "content",
+        }
+    ]
+
+    return integrations, predictors, additional_context
+
+
+def plan_sql(sql, *args, **kwargs):
+    return plan_query(parse_sql(sql, dialect="mindsdb"), *args, **kwargs)
+
+
+def test_insert_into_kb(planner_context):
+    integration_context, predictor_context, additional_context = planner_context
+    _plan_sql = partial(
+        plan_sql,
+        default_namespace="mindsdb",
+        integrations=integration_context,
+        predictor_metadata=predictor_context,
+        additional_metadata=additional_context,
+    )
+
+    # insert into kb with values
+    sql = """
+        INSERT INTO my_kb
+        (id, content, metadata)
+        VALUES
+        (1, 'hello world', '{"a": 1, "b": 2}'),
+        (2, 'hello world', '{"a": 1, "b": 2}'),
+        (3, 'hello world', '{"a": 1, "b": 2}');
+    """
+    # this should dispatch the inserted rows to the underlying model
+    # and then dispatch the query to the underlying storage
+    # TODO: need to figure out what to do with this situation
+
+    # insert into kb with select
+    sql = """
+        INSERT INTO my_kb
+        (id, content, metadata)
+        SELECT
+            id, content, metadata
+        FROM my_database.my_table
+    """
+    # this will join the subselect with the underlying model
+    # then it will dispatch the query to the underlying storage
+    equivalent_sql = """
+        INSERT INTO my_chromadb.my_table
+        (id, content, metadata, embeddings)
+        SELECT
+            id, content, metadata, embeddings
+        FROM (
+            SELECT
+                id, content, metadata, embeddings
+            FROM my_database.my_table
+            JOIN mindsdb.my_model
+        )
+    """
+    plan = _plan_sql(sql)
+    expected_plan = _plan_sql(equivalent_sql)
+
+    assert plan.steps == expected_plan.steps
+
+
+def test_select_from_kb(planner_context):
+    integration_context, predictor_context, additional_context = planner_context
+    _plan_sql = partial(
+        plan_sql,
+        default_namespace="mindsdb",
+        integrations=integration_context,
+        predictor_metadata=predictor_context,
+        additional_metadata=additional_context,
+    )
+
+    # select from kb without where
+    sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_kb
+    """
+    # this will dispatch the query to the underlying storage
+    equivalent_sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_chromadb.my_table
+    """
+    plan = _plan_sql(sql)
+    expected_plan = _plan_sql(equivalent_sql)
+
+    assert plan.steps == expected_plan.steps
+
+    # select from kb with search_query
+    sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_kb
+        WHERE
+            search_query = 'hello world'
+    """
+    # this will dispatch the search_query to the underlying model
+    # then it will dispatch the query to the underlying storage
+    equivalent_sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_chromadb.my_table
+        WHERE
+            search_vector = (
+                SELECT
+                    embeddings
+                FROM mindsdb.my_model
+                WHERE
+                    content = 'hello world'
+            )
+    """
+    plan = _plan_sql(sql)
+    expected_plan = _plan_sql(equivalent_sql)
+
+    assert plan.steps == expected_plan.steps
+
+    # select from kb with no search_query and just a metadata query
+    sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_kb
+        WHERE
+            `metadata.a` = 1
+    """
+    # this will dispatch the whole query to the underlying storage
+    equivalent_sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_chromadb.my_table
+        WHERE
+            `metadata.a` = 1
+    """
+    plan = _plan_sql(sql)
+    expected_plan = _plan_sql(equivalent_sql)
+
+    assert plan.steps == expected_plan.steps
+
+    # select from kb with search_query and metadata query
+    sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_kb
+        WHERE
+            search_query = 'hello world'
+            AND
+            `metadata.a` = 1
+    """
+    # this will dispatch the search_query to the underlying model
+    # then it will dispatch the query to the underlying storage
+    equivalent_sql = """
+        SELECT
+            id, content, embeddings, metadata
+        FROM my_chromadb.my_table
+        WHERE
+            search_vector = (
+                SELECT
+                    embeddings
+                FROM mindsdb.my_model
+                WHERE
+                    content = 'hello world'
+            )
+            AND
+            `metadata.a` = 1
+    """
+    plan = _plan_sql(sql)
+    expected_plan = _plan_sql(equivalent_sql)
+
+    assert plan.steps == expected_plan.steps
+
+
+@pytest.mark.skip(reason="not implemented")
+def test_update_kb():
+    ...
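
Usage sketch (illustrative, not part of the patch): the snippet below shows how a caller might exercise the new additional_metadata argument end to end, mirroring the fixture in the tests above. The integration, model, and knowledge base names (my_chromadb, mindsdb.my_model, my_kb) are the illustrative ones from the tests, not names required by the planner.

# A minimal example of planning a knowledge-base SELECT with the patched planner.
from mindsdb_sql import parse_sql
from mindsdb_sql.planner import plan_query

query = parse_sql(
    "SELECT id, content FROM my_kb WHERE search_query = 'hello world'",
    dialect="mindsdb",
)

plan = plan_query(
    query,
    default_namespace="mindsdb",
    integrations=[{"name": "my_chromadb", "type": "data"}],
    predictor_metadata=[{"name": "my_model", "integration_name": "mindsdb"}],
    # the knowledge base entry consumed by the new additional_metadata parameter
    additional_metadata=[{
        "name": "my_kb",
        "type": "knowledge_base",
        "model": "mindsdb.my_model",
        "storage": "my_chromadb.my_table",
    }],
)

# The resulting steps should match planning the rewritten query against
# my_chromadb.my_table directly, as asserted in test_select_from_kb.
for step in plan.steps:
    print(step)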