Skip to content

Commit 36199f8

Browse files
authored
feat: add metadata based filtering to VectorStoreOptions (#564)
1 parent f2b941b commit 36199f8

File tree

13 files changed

+3125
-2888
lines changed

13 files changed

+3125
-2888
lines changed

docs/how-to/document_search/search-documents.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,35 @@ Searching for elements is performed using a vector store. [`DocumentSearch`][rag
7373

7474
To learn more about using Hybrid Search, refer to [How to Perform Hybrid Search with Multiple Vector Stores](../vector_stores/hybrid.md).
7575

76+
## Limit results with metadata-based filtering
77+
78+
You can filter search results based on document metadata using the `where` clause in `VectorStoreOptions`. This allows you to narrow down results to specific document types, sources, or any other metadata fields you've defined.
79+
80+
```python
81+
from ragbits.core.vector_stores.base import VectorStoreOptions
82+
from ragbits.document_search import DocumentSearch, DocumentSearchOptions
83+
84+
# Create vector store options with metadata filtering
85+
vector_store_options = VectorStoreOptions(
86+
k=2, # Number of results to return
87+
score_threshold=0.6, # Minimum similarity score
88+
where={"document_meta": {"document_type": "txt"}} # Filter by document type
89+
)
90+
91+
# Create document search options with the vector store options
92+
options = DocumentSearchOptions(vector_store_options=vector_store_options)
93+
94+
# Search with the filtering options
95+
results = await document_search.search("Your search query", options=options)
96+
```
97+
98+
The `where` clause can match on any metadata field stored with the entries. For example, you can filter by:
99+
- Document type
100+
- Source
101+
- Custom metadata fields
102+
103+
This filtering happens at the vector store level, making the search more efficient by reducing the number of documents that need to be processed.
104+
76105
## Rephrase query
77106

78107
By default, the input query is provided directly to the embedding model. However, there is an option to add an additional step before vector search. Ragbits offers several common rephrasing techniques that can be utilized to refine the query and generate better embeddings for retrieval.

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Allow limiting VectorStore results by metadata (#564)
56
- Switch from imghdr to filetype for image file type check (#563)
67
- Remove prompt lab (#549)
78
- Add batched() helper method to utils (#555)

packages/ragbits-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ dependencies = [
4949

5050
[project.optional-dependencies]
5151
chroma = [
52-
"chromadb>=0.6.3,<1.0.0",
52+
"chromadb>=1.0.0,<2.0.0",
5353
]
5454
local = [
5555
"sentence-transformers>=4.0.2,<5.0.0",

packages/ragbits-core/src/ragbits/core/vector_stores/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ragbits.core.utils.config_handling import ConfigurableComponent, ObjectConstructionConfig
1414
from ragbits.core.utils.pydantic import SerializableBytes
1515

16-
WhereQuery = dict[str, str | int | float | bool]
16+
WhereQuery = dict[str, str | int | float | bool | dict]
1717

1818

1919
class VectorStoreEntry(BaseModel):
@@ -69,10 +69,13 @@ class VectorStoreOptions(Options):
6969
Note that this is based on score, which may be different from the raw
7070
similarity metric used by the vector store (see `VectorStoreResult`
7171
for more details).
72+
where: The filter dictionary - the keys are the field names and the values are the values to filter by.
73+
Fields that are not listed as keys are not filtered on.
7274
"""
7375

7476
k: int = 5
7577
score_threshold: float | None = None
78+
where: WhereQuery | None = None
7679

7780

7881
VectorStoreOptionsT = TypeVar("VectorStoreOptionsT", bound=VectorStoreOptions)

packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from uuid import UUID
44

55
import chromadb
6-
from chromadb.api import ClientAPI, types
6+
from chromadb.api import ClientAPI
7+
from chromadb.api.types import IncludeMetadataDocuments, IncludeMetadataDocumentsEmbeddingsDistances
78
from typing_extensions import Self
89

910
from ragbits.core.audit.traces import trace
@@ -193,15 +194,13 @@ async def retrieve(
193194
query_vector = (await self._embedder.embed_text([text]))[0]
194195
query_vector = cast(list[float], query_vector)
195196

197+
where_dict = self._create_chroma_filter(merged_options.where)
198+
196199
results = self._collection.query(
197200
query_embeddings=query_vector,
198201
n_results=merged_options.k,
199-
include=[
200-
types.IncludeEnum.metadatas,
201-
types.IncludeEnum.embeddings,
202-
types.IncludeEnum.distances,
203-
types.IncludeEnum.documents,
204-
],
202+
include=IncludeMetadataDocumentsEmbeddingsDistances,
203+
where=where_dict,
205204
)
206205

207206
ids = [id for batch in results.get("ids", []) for id in batch]
@@ -266,14 +265,13 @@ async def list(
266265
with trace(
267266
where=where, collection=self._collection, index_name=self._index_name, limit=limit, offset=offset
268267
) as outputs:
269-
# Cast `where` to chromadb's Where type
270-
where_chroma: chromadb.Where | None = dict(where) if where else None
268+
where_chroma = self._create_chroma_filter(where)
271269

272270
results = self._collection.get(
273271
where=where_chroma,
274272
limit=limit,
275273
offset=offset,
276-
include=[types.IncludeEnum.metadatas, types.IncludeEnum.documents],
274+
include=IncludeMetadataDocuments,
277275
)
278276

279277
ids = results.get("ids") or []
@@ -301,3 +299,22 @@ async def list(
301299
def _flatten_metadata(metadata: dict) -> dict:
302300
"""Flattens the metadata dictionary. Removes any None values as they are not supported by ChromaDB."""
303301
return {k: v for k, v in flatten_dict(metadata).items() if v is not None}
302+
303+
@staticmethod
def _create_chroma_filter(where: WhereQuery | None) -> chromadb.Where | None:
    """
    Creates a ChromaDB filter from a WhereQuery.

    Metadata is flattened with `flatten_dict` before it is stored (see `_flatten_metadata`),
    so nested filter dictionaries are flattened with the same helper to make the filter keys
    line up with the stored keys. The decision whether to wrap the conditions in `$and` is
    taken on the flattened conditions, not on the number of top-level keys, because a single
    nested key may expand into several field conditions.

    Args:
        where: The filter dictionary - the keys are the field names and the values are the values to filter by.
            Nested dictionaries address nested metadata fields.

    Returns:
        The ChromaDB filter, or None when there is nothing to filter by.
    """
    if not where:
        return None

    if len(where) == 1:
        key, value = next(iter(where.items()))
        # A single-key filter that is already in a form Chroma accepts (plain equality,
        # a top-level logical operator such as "$and", or an operator expression such as
        # {"$gt": 5}) is passed through untouched, exactly as before.
        is_operator_expression = (
            isinstance(value, dict) and bool(value) and all(str(name).startswith("$") for name in value)
        )
        if not isinstance(value, dict) or key.startswith("$") or is_operator_expression:
            return cast(chromadb.Where, where)

    conditions = [{field: expected} for field, expected in flatten_dict(where).items()]

    # NOTE(review): presumably only reachable when the filter holds nothing but empty nested
    # dictionaries; treated as "no filtering" instead of sending an empty "$and" to Chroma.
    if not conditions:
        return None

    # Chroma accepts exactly one condition at the top level; several must be combined with $and.
    if len(conditions) == 1:
        return cast(chromadb.Where, conditions[0])
    return cast(chromadb.Where, {"$and": conditions})

packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,14 @@ async def retrieve(
9090
results: list[VectorStoreResult] = []
9191

9292
for entry_id, vector in self._embeddings.items():
93+
entry = self._entries[entry_id]
94+
95+
# Apply metadata filtering
96+
if merged_options.where and not all(
97+
entry.metadata.get(key) == value for key, value in merged_options.where.items()
98+
):
99+
continue
100+
93101
# Calculate score based on vector type
94102
if isinstance(query_vector, SparseVector) and isinstance(vector, SparseVector):
95103
# For sparse vectors, use dot product between query and document vectors

packages/ragbits-core/src/ragbits/core/vector_stores/pgvector.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import re
3-
from typing import Any, NamedTuple, cast
3+
from typing import Any, NamedTuple
44
from uuid import UUID
55

66
import asyncpg
@@ -173,13 +173,19 @@ def _create_retrieve_query(
173173
# _table_name has been validated in the class constructor, and it is a valid table name.
174174
query = f"SELECT *, vector {distance_operator} $1 as distance, {score_formula} as score FROM {self._table_name}" # noqa S608
175175

176-
values: list[Any] = [
177-
self._vector_to_string(vector),
178-
]
176+
values: list[Any] = [self._vector_to_string(vector)]
177+
where_clauses = []
179178

180179
if query_options.score_threshold is not None:
181-
query += " WHERE score >= $2"
182-
values.extend([query_options.score_threshold])
180+
where_clauses.append("score >= $" + str(len(values) + 1))
181+
values.append(query_options.score_threshold)
182+
183+
if query_options.where:
184+
where_clauses.append(f"metadata @> ${len(values) + 1}")
185+
values.append(json.dumps(query_options.where))
186+
187+
if where_clauses:
188+
query += " WHERE " + " AND ".join(where_clauses)
183189

184190
query += " ORDER BY distance"
185191

@@ -351,25 +357,23 @@ async def retrieve(
351357
Returns:
352358
The retrieved entries.
353359
"""
354-
query_options = (self.default_options | options) if options else self.default_options
360+
merged_options = (self.default_options | options) if options else self.default_options
361+
355362
with trace(
356363
text=text,
364+
options=merged_options.dict(),
357365
table_name=self._table_name,
358-
query_options=query_options,
359366
vector_size=self._vector_size,
360367
distance_method=self._distance_method,
361368
embedder=repr(self._embedder),
362369
embedding_type=self._embedding_type,
363370
) as outputs:
364-
vector = (await self._embedder.embed_text([text]))[0]
365-
vector = cast(list[float], vector)
366-
367-
query_options = (self.default_options | options) if options else self.default_options
368-
retrieve_query, values = self._create_retrieve_query(vector, query_options)
371+
query_vector = (await self._embedder.embed_text([text]))[0]
372+
query, values = self._create_retrieve_query(query_vector, merged_options)
369373

370374
try:
371375
async with self._client.acquire() as conn:
372-
results = await conn.fetch(retrieve_query, *values)
376+
results = await conn.fetch(query, *values)
373377

374378
outputs.results = [
375379
VectorStoreResult(

packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
EmbeddingType,
2525
VectorStoreEntry,
2626
VectorStoreOptions,
27-
VectorStoreOptionsT,
2827
VectorStoreResult,
2928
VectorStoreWithEmbedder,
3029
WhereQuery,
@@ -214,7 +213,11 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
214213
wait=True,
215214
)
216215

217-
async def retrieve(self, text: str, options: VectorStoreOptionsT | None = None) -> list[VectorStoreResult]:
216+
async def retrieve(
217+
self,
218+
text: str,
219+
options: VectorStoreOptions | None = None,
220+
) -> list[VectorStoreResult]:
218221
"""
219222
Retrieves entries from the Qdrant collection based on vector similarity.
220223
@@ -236,7 +239,7 @@ async def retrieve(self, text: str, options: VectorStoreOptionsT | None = None)
236239
)
237240
with trace(
238241
text=text,
239-
options=merged_options,
242+
options=merged_options.dict(),
240243
index_name=self._index_name,
241244
distance_method=self._distance_method,
242245
embedder=repr(self._embedder),
@@ -252,6 +255,7 @@ async def retrieve(self, text: str, options: VectorStoreOptionsT | None = None)
252255
score_threshold=score_threshold,
253256
with_payload=True,
254257
with_vectors=True,
258+
query_filter=self._create_qdrant_filter(merged_options.where),
255259
)
256260

257261
outputs.results = []
@@ -290,16 +294,19 @@ async def remove(self, ids: list[UUID]) -> None:
290294
)
291295

292296
@staticmethod
293-
def _create_qdrant_filter(where: WhereQuery) -> Filter:
297+
def _create_qdrant_filter(where: WhereQuery | None) -> Filter:
294298
"""
295299
Creates the QdrantFilter from the given WhereQuery.
296300
297301
Args:
298-
where: The WhereQuery to filter.
302+
where: The WhereQuery to filter. If None, returns an empty filter.
299303
300304
Returns:
301305
The created filter.
302306
"""
307+
if where is None:
308+
return Filter(must=[])
309+
303310
where = flatten_dict(where) # type: ignore
304311

305312
return Filter(

packages/ragbits-core/tests/integration/vector_stores/test_vector_store.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,38 @@ async def test_handling_document_ingestion_with_different_content_and_verifying_
250250
assert document_1_content in document_contents
251251
assert document_2_new_content in document_contents
252252
assert document_2_content not in document_contents
253+
254+
255+
async def test_vector_store_retrieve_with_where_clause(
256+
text_vector_store: VectorStoreWithDenseEmbedder,
257+
vector_store_entries: list[VectorStoreEntry],
258+
) -> None:
259+
await text_vector_store.store(vector_store_entries)
260+
261+
# Test with a simple where clause
262+
results = await text_vector_store.retrieve(
263+
text="foo",
264+
options=VectorStoreOptions(
265+
where={
266+
"foo": "bar",
267+
"nested_foo": {"nested_bar": "nested_baz"},
268+
}
269+
),
270+
)
271+
272+
# Should only return the first entry which matches both conditions
273+
assert len(results) == 1
274+
assert results[0].entry.id == vector_store_entries[0].id
275+
assert results[0].entry.metadata["foo"] == "bar"
276+
assert results[0].entry.metadata["nested_foo"]["nested_bar"] == "nested_baz"
277+
278+
# Test with a where clause that matches no entries
279+
results = await text_vector_store.retrieve(
280+
text="foo",
281+
options=VectorStoreOptions(
282+
where={
283+
"foo": "nonexistent",
284+
}
285+
),
286+
)
287+
assert len(results) == 0

0 commit comments

Comments
 (0)