Skip to content

Commit

Permalink
planner for knowledgebase
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhuishi-convect committed Sep 6, 2023
1 parent 1eacabc commit a536adf
Show file tree
Hide file tree
Showing 2 changed files with 403 additions and 2 deletions.
192 changes: 190 additions & 2 deletions mindsdb_sql/planner/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def __init__(self,
integrations: list = None,
predictor_namespace=None,
predictor_metadata: list = None,
default_namespace: str = None):
default_namespace: str = None,
additional_metadata: list = None,
):
self.query = query
self.plan = QueryPlan()

Expand Down Expand Up @@ -90,8 +92,46 @@ def __init__(self,
self.projects = list(_projects)
self.databases = list(self.integrations.keys()) + self.projects

# additional metadata -- knowledge base
self.additional_metadata = {}
additional_metadata = additional_metadata or []
for metadata in additional_metadata:
if 'integration_name' not in metadata:
metadata['integration_name'] = self.predictor_namespace
idx = f'{metadata["integration_name"]}.{metadata["name"]}'.lower()

self._validate_knowledge_base_meta(metadata)
self.additional_metadata[idx] = metadata

self.statement = None

def _validate_knowledge_base_meta(self, metadata):
"""
Verify the entry for knowledge base metadata is valid
"""
TYPE_FIELD = "type"
if TYPE_FIELD not in metadata:
return
elif metadata[TYPE_FIELD] != "knowledge_base":
return
MODEL_FIELD = "model"
STORAGE_FIELD = "storage"
if MODEL_FIELD not in metadata:
raise PlanningException(f"Knowledge base metadata must contain a {MODEL_FIELD} field")
else:
# we enforce to specify a full qualified name for the model
# e.g., integration_name.model_name
model_name = metadata[MODEL_FIELD]
if len(model_name.split(".")) != 2:
raise PlanningException(f"Knowledge base model name must be in the format of integration_name.model_name")

if STORAGE_FIELD not in metadata:
raise PlanningException(f"Knowledge base metadata must contain a {STORAGE_FIELD} field")
else:
storage_name = metadata[STORAGE_FIELD]
if len(storage_name.split(".")) != 2:
raise PlanningException(f"Knowledge base storage name must be in the format of integration_name.table_name")

def is_predictor(self, identifier):
    """
    Tell whether get_predictor finds an entry for the identifier.

    :param identifier: identifier passed through to get_predictor
    :return: True if get_predictor returns something other than None
    """
    predictor_info = self.get_predictor(identifier)
    return predictor_info is not None

Expand Down Expand Up @@ -124,6 +164,30 @@ def get_predictor(self, identifier):
info['name'] = name
return info

def get_knowledge_base(self, identifier):
    """
    Look up the knowledge base metadata entry for an identifier.

    The lookup key is the lower-cased "namespace.name" built from the last
    two parts of the identifier. When the identifier has a single part,
    self.default_namespace is used as the namespace; if that is None too,
    the key is the bare name.

    :param identifier: object exposing a `parts` sequence
    :return: the metadata entry if it exists and its "type" is
        "knowledge_base", otherwise None
    """
    parts = list(identifier.parts)
    kb_name = parts[-1]
    namespace = parts[-2] if len(parts) > 1 else self.default_namespace

    key_parts = [kb_name] if namespace is None else [namespace, kb_name]
    lookup_key = '.'.join(key_parts).lower()

    entry = self.additional_metadata.get(lookup_key)
    if entry is None or entry.get("type") != "knowledge_base":
        return None
    return entry

def is_knowledge_base(self, identifier):
    """
    Tell whether the identifier resolves to a knowledge base.

    :param identifier: identifier passed through to get_knowledge_base
    :return: True if get_knowledge_base returns something other than None
    """
    kb_info = self.get_knowledge_base(identifier)
    return kb_info is not None

def prepare_integration_select(self, database, query):
# replacement for 'utils.recursively_disambiguate_*' functions from utils
# main purpose: make tests working (don't change planner outputs)
Expand Down Expand Up @@ -1155,6 +1219,10 @@ def plan_create_table(self, query):

def plan_insert(self, query):
table = query.table
if self.is_knowledge_base(table):
# knowledgebase table
return self.plan_insert_knowledge_base(query)

if query.from_select is not None:
integration_name = query.table.parts[0]

Expand Down Expand Up @@ -1210,7 +1278,12 @@ def plan_select(self, query, integration=None):
from_table = query.from_table

if isinstance(from_table, Identifier):
return self.plan_select_identifier(query)
# decide from_table is a knowledgebase table or a table from integration
if self.is_knowledge_base(from_table):
# knowledgebase table
return self.plan_select_knowledege_base(query)
else:
return self.plan_select_identifier(query)
elif isinstance(from_table, Select):
return self.plan_nested_select(query)
elif isinstance(from_table, Join):
Expand Down Expand Up @@ -1260,6 +1333,121 @@ def plan_union(self, query):

return self.plan.add_step(UnionStep(left=query1.result, right=query2.result, unique=query.unique))


def plan_select_knowledege_base(self, query):
    """
    Plan a SELECT whose FROM table is a knowledge base.

    The query is redirected to the storage table named in the knowledge base
    metadata and then planned with plan_select. Before that, every where
    condition whose left operand is the `search_query` column is rewritten:

        search_query = 'some text'
        ->
        search_vector = (select embeddings from model_name where content = 'some text')

    The column is matched by the last part of its name, so a qualified
    reference such as `kb.search_query` is rewritten as well.

    NOTE: `query` is modified in place (where clause and from_table).

    :param query: select query whose from_table identifies a knowledge base
    :return: whatever plan_select returns for the redirected query
    """
    SEARCH_QUERY = "search_query"  # TODO: need to make it as a constant
    MODEL_FIELD = "model"  # TODO: need to make it as a constant
    STORAGE_FIELD = "storage"  # TODO: need to make it as a constant

    kb_metadata = self.get_knowledge_base(query.from_table)
    vector_database_table = kb_metadata[STORAGE_FIELD]
    model_name = kb_metadata[MODEL_FIELD]

    # column names can be overridden per knowledge base
    content_field = kb_metadata.get("content_field") or "content"
    embeddings_field = kb_metadata.get("embeddings_field") or "embeddings"
    search_vector_field = kb_metadata.get("search_vector_field") or "search_vector"

    def rewrite_search_query_clause(node, **kwargs):
        if not isinstance(node, BinaryOperation):
            return
        left = node.args[0]
        # match on the last name part: comparing against a bare
        # Identifier(SEARCH_QUERY) would miss qualified columns
        if not (isinstance(left, Identifier) and left.parts[-1] == SEARCH_QUERY):
            return

        search_text = node.args[1]
        node.args[0] = Identifier(search_vector_field)
        # the predictor is applied to the search text through a sub-select
        node.args[1] = Select(
            targets=[Identifier(embeddings_field)],
            from_table=Identifier(model_name),
            where=BinaryOperation(
                op="=",
                args=[
                    Identifier(content_field),
                    search_text
                ]
            )
        )

    # when no condition involves SEARCH_QUERY the traversal changes nothing
    # and the query goes to the storage table untouched
    utils.query_traversal(
        query.where,
        callback=rewrite_search_query_clause
    )

    # dispatch to the underlying storage table
    query.from_table = Identifier(vector_database_table)
    return self.plan_select(query)

def plan_insert_knowledge_base(self, query: Insert):
    """
    Plan an INSERT whose target table is a knowledge base.

    The insert is redirected to the storage table named in the knowledge
    base metadata. Only INSERT ... SELECT is handled. If the insert's column
    list does not mention the embeddings column, the select is rewritten to
    read from a join of its original source with the model, and the
    embeddings column is added to the select targets (and to the column
    list, when one was given).

    NOTE: `query` and its from_select are modified in place.

    :param query: insert query whose table identifies a knowledge base
    :return: whatever plan_insert returns for the redirected query
    :raises NotImplementedError: if the insert has no from_select
    """
    kb_meta = self.get_knowledge_base(query.table)
    embeddings_column = kb_meta.get("embeddings_field") or "embeddings"

    storage_table = kb_meta["storage"]  # TODO: need to make the key a constant
    embedding_model = kb_meta["model"]  # TODO: need to make the key a constant

    # from here on the insert targets the underlying storage table
    query.table = Identifier(storage_table)

    if query.from_select is None:
        raise NotImplementedError("Not implemented insert without select")

    # collect every reference to the embeddings column in the column list;
    # if there is one, the predictor does not need to be applied
    embeddings_refs = []

    def collect_embeddings_refs(node, **kwargs):
        if isinstance(node, Identifier) and node.parts[-1] == embeddings_column:
            embeddings_refs.append(node)

    utils.query_traversal(
        query.columns,
        callback=collect_embeddings_refs
    )

    if not embeddings_refs:
        # rewrite the select statement to join with the model table
        source_select: Select = query.from_select
        source_select.targets.append(Identifier(embeddings_column))

        inner_targets = copy.deepcopy(source_select.targets)
        joined_source = Join(
            left=source_select.from_table,
            right=Identifier(embedding_model),
            join_type="JOIN"
        )
        source_select.from_table = Select(
            targets=inner_targets,
            from_table=joined_source
        )

        # append the embeddings field to the columns list
        if query.columns:
            query.columns.append(Identifier(embeddings_column))

    return self.plan_insert(query)

# method for compatibility
def from_query(self, query=None):
if query is None:
Expand Down
Loading

0 comments on commit a536adf

Please sign in to comment.