Skip to content

Commit 38bfb3d

Browse files
authored
Auto Hashing ID for VectorDB Classes (#4746) (#4789)
1 parent 960fbf0 commit 38bfb3d

File tree

4 files changed

+136
-24
lines changed

4 files changed

+136
-24
lines changed

autogen/agentchat/contrib/vectordb/base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import hashlib
2+
import os
13
from typing import (
24
Any,
35
Callable,
@@ -16,6 +18,8 @@
1618
Vector = Union[Sequence[float], Sequence[int]]
1719
ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does
1820

21+
HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))
22+
1923

2024
class Document(TypedDict):
2125
"""A Document is a record in the vector database.
@@ -26,7 +30,7 @@ class Document(TypedDict):
2630
embedding: Vector, Optional | the vector representation of the content.
2731
"""
2832

29-
id: ItemID
33+
id: Optional[ItemID]
3034
content: str
3135
metadata: Optional[Metadata]
3236
embedding: Optional[Vector]
@@ -108,6 +112,19 @@ def delete_collection(self, collection_name: str) -> Any:
108112
"""
109113
...
110114

115+
def generate_chunk_ids(self, chunks: List[str], hash_length: Optional[int] = None) -> List[ItemID]:
    """Generate deterministic chunk IDs to ensure non-duplicate uploads.

    Each ID is the blake2b hash of the chunk's UTF-8 bytes, truncated to
    ``hash_length`` hex characters, so identical content always maps to the
    same ID.

    Args:
        chunks (List[str]): A list of chunks (strings) to hash.
        hash_length (Optional[int]): The desired length of each hash. When
            None (the default), the module-level ``HASH_LENGTH`` constant is
            used, which can be overridden via the ``HASH_LENGTH`` environment
            variable. NOTE(review): very short truncations increase the
            chance of hash collisions between distinct chunks.

    Returns:
        List[ItemID]: A list of generated chunk IDs, one per chunk, in input
        order. Empty input yields an empty list.
    """
    # Fix: declared without `self` although callers invoke it as an instance
    # method (`self.generate_chunk_ids(...)` in mongodb.py, `db.generate_chunk_ids(...)`
    # in tests); without `self` the instance would bind to `chunks`.
    if hash_length is None:
        hash_length = HASH_LENGTH
    return [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:hash_length] for chunk in chunks]
127+
111128
def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
112129
"""
113130
Insert documents into the collection of the vector database.

autogen/agentchat/contrib/vectordb/mongodb.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,17 @@ def _wait_for_document(self, collection: Collection, index_name: str, doc: Docum
123123
if query_result and query_result[0][0]["_id"] == doc["id"]:
124124
return
125125
sleep(_DELAY)
126-
127-
raise TimeoutError(f"Document {self.index_name} is not ready!")
126+
if (
127+
query_result
128+
and float(query_result[0][1]) == 1.0
129+
and query_result[0][0].get("metadata") == doc.get("metadata")
130+
):
131+
# Handles edge case where document is uploaded with a specific user-generated ID, then the identical content is uploaded with a hash generated ID.
132+
logger.warning(
133+
f"""Documents may be ready, the search has found identical content with a different ID and {"identical" if query_result[0][0].get("metadata") == doc.get("metadata") else "different"} metadata. Duplicate ID: {str(query_result[0][0]["_id"])}"""
134+
)
135+
else:
136+
raise TimeoutError(f"Document {self.index_name} is not ready!")
128137

129138
def _get_embedding_size(self):
130139
return len(self.embedding_function(_SAMPLE_SENTENCE)[0])
@@ -275,33 +284,49 @@ def insert_docs(
275284
276285
For large numbers of Documents, insertion is performed in batches.
277286
287+
Documents are recommended to not have an ID field, as the method will generate Hashed ID's for them.
288+
278289
Args:
279-
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
290+
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`, which may contain an ID. Documents without ID's will have them generated.
280291
collection_name: str | The name of the collection. Default is None.
281292
upsert: bool | Whether to update the document if it exists. Default is False.
282293
batch_size: Number of documents to be inserted in each batch
294+
kwargs: Additional keyword arguments. Use `hash_length` to set the length of the hash generated ID's, use `overwrite_ids` to overwrite existing ID's with Hashed Values.
283295
"""
296+
hash_length = kwargs.get("hash_length")
297+
overwrite_ids = kwargs.get("overwrite_ids", False)
298+
299+
if any(doc.get("content") is None for doc in docs):
300+
raise ValueError("The document content is required.")
301+
284302
if not docs:
285303
logger.info("No documents to insert.")
286304
return
287305

306+
docs = deepcopy(docs)
288307
collection = self.get_collection(collection_name)
308+
309+
assert (
310+
len({doc.get("id") is None for doc in docs}) == 1
311+
), "Documents provided must all have ID's or all not have ID's"
312+
313+
if docs[0].get("id") is None or overwrite_ids:
314+
logger.info("No id field in the documents. The documents will be inserted with Hash generated IDs.")
315+
content = [doc["content"] for doc in docs]
316+
ids = (
317+
self.generate_chunk_ids(content, hash_length=hash_length)
318+
if hash_length
319+
else self.generate_chunk_ids(content)
320+
)
321+
docs = [{**doc, "id": id} for doc, id in zip(docs, ids)]
322+
289323
if upsert:
290324
self.update_docs(docs, collection.name, upsert=True)
325+
291326
else:
292-
# Sanity checking the first document
293-
if docs[0].get("content") is None:
294-
raise ValueError("The document content is required.")
295-
if docs[0].get("id") is None:
296-
raise ValueError("The document id is required.")
297-
298-
input_ids = set()
299-
result_ids = set()
300-
id_batch = []
301-
text_batch = []
302-
metadata_batch = []
303-
size = 0
304-
i = 0
327+
input_ids, result_ids = set(), set()
328+
id_batch, text_batch, metadata_batch = [], [], []
329+
size, i = 0, 0
305330
for doc in docs:
306331
id = doc["id"]
307332
text = doc["content"]
@@ -314,9 +339,7 @@ def insert_docs(
314339
if (i + 1) % batch_size == 0 or size >= 47_000_000:
315340
result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
316341
input_ids.update(id_batch)
317-
id_batch = []
318-
text_batch = []
319-
metadata_batch = []
342+
id_batch, text_batch, metadata_batch = [], [], []
320343
size = 0
321344
i += 1
322345
if text_batch:
@@ -365,7 +388,8 @@ def _insert_batch(
365388
]
366389
# insert the documents in MongoDB Atlas
367390
insert_result = collection.insert_many(to_insert) # type: ignore
368-
return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs
391+
# TODO Remove this. Replace by log like update_docs
392+
return insert_result.inserted_ids
369393

370394
def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
371395
"""Update documents, including their embeddings, in the Collection.
@@ -375,11 +399,14 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg
375399
Uses deepcopy to avoid changing docs.
376400
377401
Args:
378-
docs: List[Document] | A list of documents.
402+
docs: List[Document] | A list of documents, with ID, to ensure the correct document is updated.
379403
collection_name: str | The name of the collection. Default is None.
380404
kwargs: Any | Use upsert=True` to insert documents whose ids are not present in collection.
381405
"""
382-
406+
provided_doc_count = len(docs)
407+
docs = [doc for doc in docs if doc.get("id") is not None]
408+
if len(docs) != provided_doc_count:
409+
logger.info(f"{provided_doc_count - len(docs)} will not be updated, as they did not contain an ID")
383410
n_docs = len(docs)
384411
logger.info(f"Preparing to embed and update {n_docs=}")
385412
# Compute the embeddings

autogen/agentchat/contrib/vectordb/qdrant.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
2+
import hashlib
23
import logging
3-
import os
4+
import uuid
45
from typing import Callable, List, Optional, Sequence, Tuple, Union
56

67
from .base import Document, ItemID, QueryResults, VectorDB
@@ -155,6 +156,18 @@ def delete_collection(self, collection_name: str) -> None:
155156
"""
156157
return self.client.delete_collection(collection_name)
157158

159+
def generate_chunk_ids(self, chunks: List[str]) -> List[ItemID]:
    """Generate deterministic chunk IDs to ensure non-duplicate uploads.

    Each chunk is hashed with MD5 and the 128-bit digest is rendered as a
    UUID string, since Qdrant requires point IDs to be UUIDs or integers.
    Identical content always maps to the same ID.

    Args:
        chunks (List[str]): A list of chunks (strings) to hash.

    Returns:
        List[ItemID]: A list of generated chunk IDs (UUID strings), one per
        chunk, in input order. Empty input yields an empty list.
    """
    # Fix: declared without `self` although the base-class contract and the
    # sibling VectorDB implementations call this as an instance method;
    # without `self` the instance would bind to `chunks`.
    return [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]
170+
158171
def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
159172
"""
160173
Insert documents into the collection of the vector database.

test/agentchat/contrib/vectordb/test_mongodb.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,28 @@ def example_documents() -> List[Document]:
107107
]
108108

109109

110+
@pytest.fixture
def id_less_example_documents() -> List[Document]:
    """Documents without an ``id`` field, used to exercise hash-generated IDs."""
    rows = [
        ("Stars are Big.", {"a": 1}),
        ("Atoms are Small.", {"b": 1}),
        ("Clouds are White.", {"c": 1}),
        ("Grass is Green.", {"d": 1, "e": 2}),
    ]
    return [Document(content=text, metadata=meta) for text, meta in rows]
119+
120+
121+
@pytest.fixture
def id_mix_example_documents() -> List[Document]:
    """Documents where only some entries carry a user-supplied ``id``.

    Used to verify that ``insert_docs`` rejects batches that mix documents
    with and without IDs. (Previous docstring was copy-pasted from the
    id-less fixture and misdescribed this one.)
    """
    return [
        Document(id="123", content="Stars are Big.", metadata={"a": 1}),
        Document(content="Atoms are Small.", metadata={"b": 1}),
        Document(id="321", content="Clouds are White.", metadata={"c": 1}),
        Document(content="Grass is Green.", metadata={"d": 1, "e": 2}),
    ]
130+
131+
110132
@pytest.fixture
111133
def db_with_indexed_clxn(collection_name):
112134
"""VectorDB with a collection created immediately"""
@@ -212,6 +234,39 @@ def test_insert_docs(db, collection_name, example_documents):
212234
assert len(found[0]["embedding"]) == 384
213235

214236

237+
def test_insert_docs_no_id(db, collection_name, id_less_example_documents):
    """Documents lacking ``id`` are stored under hash-generated ``_id`` values."""
    # Inserting without a collection name must fail fast.
    with pytest.raises(ValueError) as exc:
        db.insert_docs(id_less_example_documents)
    assert "No collection is specified" in str(exc.value)

    # Start from a fresh, empty collection.
    db.delete_collection(collection_name)
    collection = db.create_collection(collection_name)

    db.insert_docs(id_less_example_documents, collection_name=collection_name)
    stored = list(collection.find({}))
    assert len(stored) == len(id_less_example_documents)

    # Stored documents carry "_id" and "embedding" but no separate "id" field.
    expected_keys = {"_id", "content", "metadata", "embedding"}
    assert all(set(doc.keys()) == expected_keys for doc in stored)

    # The "_id" values are exactly the hashes of the document contents.
    contents = [doc.get("content") for doc in id_less_example_documents]
    assert {doc["_id"] for doc in stored} == set(db.generate_chunk_ids(contents))

    # Embedding vectors have the model's expected dimensionality.
    assert len(stored[0]["embedding"]) == 384
258+
259+
260+
def test_insert_docs_mix_id(db, collection_name, id_mix_example_documents):
    """Batches mixing documents with and without IDs are rejected."""
    # Inserting without a collection name must fail before any ID validation.
    with pytest.raises(ValueError) as excinfo:
        db.insert_docs(id_mix_example_documents)
    assert "No collection is specified" in str(excinfo.value)
    # Mixed presence/absence of IDs is rejected outright.
    with pytest.raises(AssertionError, match="Documents provided must all have ID's or all not have ID's"):
        db.insert_docs(id_mix_example_documents, collection_name, upsert=True)
268+
269+
215270
def test_update_docs(db_with_indexed_clxn, example_documents):
216271
db, collection = db_with_indexed_clxn
217272
# Use update_docs to insert new documents

0 commit comments

Comments
 (0)