Skip to content

Commit

Permalink
Make Infinity adapt to the `exists` condition.
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinHuSh committed Jan 26, 2025
1 parent b4303f6 commit 82676fd
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 14 deletions.
10 changes: 9 additions & 1 deletion api/apps/kb_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import DOC_ENGINE
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils import get_uuid
from api.db import StatusEnum, FileSource
Expand Down Expand Up @@ -96,6 +97,13 @@ def update():
return get_data_error_result(
message="Can't find this knowledgebase!")

if req.get("parser_id", "") == "tag" and DOC_ENGINE == "infinity":
return get_json_result(
data=False,
message='The chunk method Tag has not been supported by Infinity yet.',
code=settings.RetCode.OPERATING_ERROR
)

if req["name"].lower() != kb.name.lower() \
and len(
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
Expand All @@ -112,7 +120,7 @@ def update():
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)

e, kb = KnowledgebaseService.get_by_id(kb.id)
Expand Down
4 changes: 3 additions & 1 deletion rag/raptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __call__(self, chunks, random_state, callback=None):
start, end = 0, len(chunks)
if len(chunks) <= 1:
return
chunks = [(s, a) for s, a in chunks if len(a) > 0]
chunks = [(s, a) for s, a in chunks if s and len(a) > 0]

def summarize(ck_idx, lock):
nonlocal chunks
Expand Down Expand Up @@ -125,6 +125,8 @@ def summarize(ck_idx, lock):
threads = []
for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
if not ck_idx:
continue
threads.append(executor.submit(summarize, ck_idx, lock))
wait(threads, return_when=ALL_COMPLETED)
for th in threads:
Expand Down
2 changes: 1 addition & 1 deletion rag/utils/es_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseI
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exist":
if k == "exists":
bqry.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
Expand Down
66 changes: 55 additions & 11 deletions rag/utils/infinity_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,23 @@
logger = logging.getLogger('ragflow.infinity_conn')


def equivalent_condition_to_str(condition: dict) -> str | None:
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
clmns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)

def exists(cln):
nonlocal clmns
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
ty, de = clmns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"

cond = list()
for k, v in condition.items():
if not isinstance(k, str) or k in ["kb_id"] or not v:
Expand All @@ -61,8 +76,15 @@ def equivalent_condition_to_str(condition: dict) -> str | None:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
Expand Down Expand Up @@ -294,7 +316,11 @@ def search(
filter_cond = None
filter_fulltext = ""
if condition:
filter_cond = equivalent_condition_to_str(condition)
for indexName in indexNames:
table_name = f"{indexName}_{knowledgebaseIds[0]}"
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
break

for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
Expand Down Expand Up @@ -434,12 +460,21 @@ def insert(
self.createIdx(indexName, knowledgebaseId, vector_size)
table_instance = db_instance.get_table(table_name)

# embedding fields can't have a default value....
embedding_clmns = []
clmns = table_instance.show_columns().rows()
for n, ty, _, _ in clmns:
r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
if not r:
continue
embedding_clmns.append((n, int(r.group(1))))

docs = copy.deepcopy(documents)
for d in docs:
assert "_id" not in d
assert "id" in d
for k, v in d.items():
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]:
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
assert isinstance(v, list)
d[k] = "###".join(v)
elif re.search(r"_feas$", k):
Expand All @@ -454,6 +489,11 @@ def insert(
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
d[k] = "_".join(f"{num:08x}" for num in v)

for n, vs in embedding_clmns:
if n in d:
continue
d[n] = [0] * vs
ids = ["'{}'".format(d["id"]) for d in docs]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
Expand All @@ -475,11 +515,11 @@ def update(
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_instance = db_instance.get_table(table_name)
if "exist" in condition:
del condition["exist"]
filter = equivalent_condition_to_str(condition)
#if "exists" in condition:
# del condition["exists"]
filter = equivalent_condition_to_str(condition, table_instance)
for k, v in list(newValue.items()):
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]:
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
assert isinstance(v, list)
newValue[k] = "###".join(v)
elif re.search(r"_feas$", k):
Expand All @@ -496,9 +536,11 @@ def update(
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
newValue[k] = "_".join(f"{num:08x}" for num in v)
elif k == "remove" and v in [PAGERANK_FLD]:
elif k == "remove":
del newValue[k]
newValue[v] = 0
if v in [PAGERANK_FLD]:
newValue[v] = 0

logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
table_instance.update(filter, newValue)
self.connPool.release_conn(inf_conn)
Expand All @@ -508,14 +550,14 @@ def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
filter = equivalent_condition_to_str(condition)
try:
table_instance = db_instance.get_table(table_name)
except Exception:
logger.warning(
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
)
return 0
filter = equivalent_condition_to_str(condition, table_instance)
logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
Expand Down Expand Up @@ -553,7 +595,7 @@ def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[s
v = res[fieldnm][i]
if isinstance(v, Series):
v = list(v)
elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]:
elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
assert isinstance(v, str)
v = [kwd for kwd in v.split("###") if kwd]
elif fieldnm == "position_int":
Expand Down Expand Up @@ -584,6 +626,8 @@ def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: l
ans = {}
num_rows = len(res)
column_id = res["id"]
if fieldnm not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[fieldnm][i]
Expand Down

0 comments on commit 82676fd

Please sign in to comment.