Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from redisvl.query.aggregate import AggregationQuery, HybridQuery
from redisvl.query.aggregate import (
AggregationQuery,
HybridQuery,
MultiVectorQuery,
Vector,
)
from redisvl.query.query import (
BaseQuery,
BaseVectorQuery,
Expand All @@ -21,4 +26,6 @@
"TextQuery",
"AggregationQuery",
"HybridQuery",
"MultiVectorQuery",
"Vector",
]
174 changes: 174 additions & 0 deletions redisvl/query/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,41 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel, field_validator
from redis.commands.search.aggregation import AggregateRequest, Desc

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType
from redisvl.utils.token_escaper import TokenEscaper
from redisvl.utils.utils import lazy_import

nltk = lazy_import("nltk")
nltk_stopwords = lazy_import("nltk.corpus.stopwords")


class Vector(BaseModel):
    """
    Simple object containing the necessary arguments to perform a multi vector query.

    Each instance describes one query vector: the values to search with, the
    name of the field it is compared against, the data type used when the
    values are serialized, and the weight applied to its score.
    """

    # Query vector, either as a list of floats or as already-encoded bytes.
    vector: Union[List[float], bytes]
    # Name of the vector field this query vector is compared against.
    field_name: str
    # Data type name; must match a VectorDataType member, case-insensitively.
    dtype: str = "float32"
    # Multiplier applied to this vector's score when scores are combined.
    weight: float = 1.0

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, dtype: str) -> str:
        """Check that ``dtype`` names a supported vector data type.

        Args:
            dtype (str): The data type name supplied for the vector.

        Returns:
            str: The value exactly as supplied (its case is not normalized).

        Raises:
            ValueError: If ``dtype`` does not match any ``VectorDataType``
                member when upper-cased.
        """
        try:
            VectorDataType(dtype.upper())
        except ValueError:
            # `from None` drops the enum's own lookup error from the traceback;
            # the message below already names the bad value and the valid ones.
            raise ValueError(
                f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
            ) from None

        return dtype


class AggregationQuery(AggregateRequest):
"""
Base class for aggregation queries used to create aggregation queries for Redis.
Expand Down Expand Up @@ -227,3 +252,152 @@ def _build_query_string(self) -> str:
def __str__(self) -> str:
"""Return the string representation of the query."""
return " ".join([str(x) for x in self.build_args()])


class MultiVectorQuery(AggregationQuery):
    """
    MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
    The final score will be a weighted combination of the individual vector similarity scores
    following the formula:

    score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )

    where each individual score is computed from the yielded vector distance as
    ``(2 - distance) / 2``.

    Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.

    .. code-block:: python

        from redisvl.query import MultiVectorQuery, Vector
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        vector_1 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float32",
            weight=0.7,
        )
        vector_2 = Vector(
            vector=[0.5, 0.5],
            field_name="image_vector",
            dtype="bfloat16",
            weight=0.2,
        )
        vector_3 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float64",
            weight=0.5,
        )

        query = MultiVectorQuery(
            vectors=[vector_1, vector_2, vector_3],
            filter_expression=None,
            num_results=10,
            return_fields=["field1", "field2"],
            dialect=2,
        )

        results = index.query(query)
    """

    _vectors: List[Vector]

    def __init__(
        self,
        vectors: Union[Vector, List[Vector]],
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        num_results: int = 10,
        return_score: bool = False,
        dialect: int = 2,
    ):
        """
        Instantiates a MultiVectorQuery object.

        Args:
            vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
                Defaults to None.
            num_results (int, optional): The number of results to return. Defaults to 10.
            return_score (bool): Whether to return the combined vector similarity score.
                Defaults to False. NOTE: this constructor does not currently read the
                flag; ``combined_score`` is always computed and used for sorting.
            dialect (int, optional): The Redis dialect version. Defaults to 2.

        Raises:
            TypeError: If ``vectors`` is not a Vector or a list of Vectors.
            ValueError: If ``vectors`` is empty.
        """

        self._filter_expression = filter_expression
        self._num_results = num_results

        # A single Vector must be wrapped before anything tries to iterate it.
        if isinstance(vectors, Vector):
            self._vectors = [vectors]
        else:
            # Copy so that later mutation of the caller's list cannot make the
            # lazily computed `params` disagree with the query string built below.
            self._vectors = list(vectors)  # type: ignore

        if not all(isinstance(v, Vector) for v in self._vectors):
            raise TypeError(
                "vector argument must be a Vector object or list of Vector objects."
            )

        # With no vectors both the query string and the combined score
        # expression would be empty, so fail early with a clear message.
        if not self._vectors:
            raise ValueError(
                "At least one Vector is required to build a MultiVectorQuery."
            )

        query_string = self._build_query_string()
        super().__init__(query_string)

        # calculate the respective vector similarities
        for i, _ in enumerate(self._vectors):
            self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})

        # construct the scoring string based on the vector similarity scores and weights
        combined_score_string = " + ".join(
            f"@score_{i} * {v.weight}" for i, v in enumerate(self._vectors)
        )

        self.apply(combined_score=combined_score_string)

        self.sort_by(Desc("@combined_score"), max=num_results)  # type: ignore
        self.dialect(dialect)
        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Each vector is exposed as ``vector_<i>``, matching the ``$vector_<i>``
        placeholders in the query string. Vectors given as lists are converted
        with ``array_to_buffer`` using the vector's dtype; bytes are passed
        through unchanged.

        Returns:
            Dict[str, Any]: The parameters for the aggregation.
        """
        params: Dict[str, Any] = {}
        for i, v in enumerate(self._vectors):
            blob = v.vector
            if isinstance(blob, list):
                blob = array_to_buffer(blob, dtype=v.dtype)  # type: ignore
            params[f"vector_{i}"] = blob
        return params

    def _build_query_string(self) -> str:
        """Build the full query string for multi vector search with optional filtering.

        One ``VECTOR_RANGE`` clause is emitted per vector, each yielding its
        distance as ``distance_<i>``; the clauses are joined with ``|``. When a
        filter expression is set, the joined clauses and the filter are each
        parenthesized and combined.

        Returns:
            str: The query string.
        """

        # One range clause per vector. NOTE(review): the 2.0 radius is presumably
        # the maximum cosine distance, so every document yields a distance - confirm.
        range_query = " | ".join(
            f"@{v.field_name}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
            for i, v in enumerate(self._vectors)
        )

        filter_expression = self._filter_expression
        if isinstance(self._filter_expression, FilterExpression):
            filter_expression = str(self._filter_expression)

        if filter_expression:
            return f"({range_query}) AND ({filter_expression})"
        else:
            return f"{range_query}"

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join([str(x) for x in self.build_args()])
90 changes: 90 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,96 @@ def sample_data(sample_datetimes):
]


@pytest.fixture
def multi_vector_data(sample_datetimes):
    """Return seven sample user records, each carrying three embeddings.

    Every record has the same keys, in the order given by ``columns`` below.
    The ``last_updated`` values are timestamps taken from the ``low``, ``mid``
    and ``high`` entries of ``sample_datetimes``.
    """
    low_ts = sample_datetimes["low"].timestamp()
    mid_ts = sample_datetimes["mid"].timestamp()
    high_ts = sample_datetimes["high"].timestamp()

    location_a = "-122.4194,37.7749"
    location_b = "-110.0839,37.3861"

    columns = (
        "user",
        "age",
        "job",
        "description",
        "last_updated",
        "credit_score",
        "location",
        "user_embedding",
        "image_embedding",
        "audio_embedding",
    )

    rows = [
        (
            "john",
            18,
            "engineer",
            "engineers conduct trains that ride on train tracks",
            low_ts,
            "high",
            location_a,
            [0.1, 0.1, 0.5],
            [0.1, 0.1, 0.1, 0.1, 0.1],
            [34, 18.5, -6.0, -12, 115, 96.5],
        ),
        (
            "mary",
            14,
            "doctor",
            "a medical professional who treats diseases and helps people stay healthy",
            low_ts,
            "low",
            location_a,
            [0.1, 0.1, 0.5],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.0, -1.06, 4.55, -1.93, 0.0, 1.53],
        ),
        (
            "nancy",
            94,
            "doctor",
            "a research scientist specializing in cancers and diseases of the lungs",
            mid_ts,
            "high",
            location_a,
            [0.7, 0.1, 0.5],
            [0.1, 0.1, 0.3, 0.3, 0.5],
            [2.75, -0.33, -3.01, -0.52, 5.59, -2.30],
        ),
        (
            "tyler",
            100,
            "engineer",
            "a software developer with expertise in mathematics and computer science",
            mid_ts,
            "high",
            location_b,
            [0.1, 0.4, 0.5],
            [-0.1, -0.2, -0.3, -0.4, -0.5],
            [1.11, -6.73, 5.41, 1.04, 3.92, 0.73],
        ),
        (
            "tim",
            12,
            "dermatologist",
            "a medical professional specializing in diseases of the skin",
            mid_ts,
            "high",
            location_b,
            [0.4, 0.4, 0.5],
            [-0.1, 0.0, 0.6, 0.0, -0.9],
            [0.03, -2.67, -2.08, 4.57, -2.33, 0.0],
        ),
        (
            "taimur",
            15,
            "CEO",
            "high stress, but financially rewarding position at the head of a company",
            high_ts,
            "low",
            location_b,
            [0.6, 0.1, 0.5],
            [1.1, 1.2, -0.3, -4.1, 5.0],
            [0.68, 0.26, 2.08, 2.96, 0.01, 5.13],
        ),
        (
            "joe",
            35,
            "dentist",
            "like the tooth fairy because they'll take your teeth, but you have to pay them!",
            high_ts,
            "medium",
            location_b,
            [-0.1, -0.1, -0.5],
            [-0.8, 2.0, 3.1, 1.5, -1.6],
            [0.91, 7.10, -2.14, -0.52, -6.08, -5.53],
        ),
    ]

    # zip keeps the column order, so each dict has the same key order as before.
    return [dict(zip(columns, row)) for row in rows]


def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
"--run-api-tests",
Expand Down
Loading