Skip to content

fix issue with nested vector fields and python 3.13 issubclass changes #699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ]
pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, we definitely needed to test against 3.13

redisstack: [ "latest" ]
fail-fast: false
services:
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,7 @@ tests_sync/
# spelling cruft
*.dic

.idea
.idea

# version files
.tool-versions
19 changes: 10 additions & 9 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from .. import redis
from ..checks import has_redis_json, has_redisearch
from ..connections import get_redis_connection
from ..util import ASYNC_MODE
from ..util import ASYNC_MODE, has_numeric_inner_type, is_numeric_type
from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
Expand Down Expand Up @@ -406,7 +406,6 @@ class RediSearchFieldTypes(Enum):


# TODO: How to handle Geo fields?
NUMERIC_TYPES = (float, int, decimal.Decimal)
DEFAULT_PAGE_SIZE = 1000


Expand Down Expand Up @@ -578,7 +577,7 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
)
elif field_type is bool:
return RediSearchFieldTypes.TAG
elif any(issubclass(field_type, t) for t in NUMERIC_TYPES):
elif is_numeric_type(field_type):
# Index numeric Python types as NUMERIC fields, so we can support
# range queries.
return RediSearchFieldTypes.NUMERIC
Expand Down Expand Up @@ -1805,7 +1804,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
schema = cls.schema_for_type(name, embedded_cls, field_info)
elif typ is bool:
schema = f"{name} TAG"
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
elif is_numeric_type(typ):
vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
)
Expand Down Expand Up @@ -2004,9 +2003,7 @@ def schema_for_type(
field_info, "vector_options", None
)
try:
is_vector = vector_options and any(
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
)
is_vector = vector_options and has_numeric_inner_type(typ)
except IndexError:
raise RedisModelError(
f"Vector field '{name}' must be annotated as a container type"
Expand Down Expand Up @@ -2104,7 +2101,11 @@ def schema_for_type(
# a proper type, we can pull the type information from the origin of the first argument.
if not isinstance(typ, type):
type_args = typing_get_args(field_info.annotation)
typ = type_args[0].__origin__
typ = (
getattr(type_args[0], "__origin__", type_args[0])
if type_args
else typ
)

# TODO: GEO field
if is_vector and vector_options:
Expand All @@ -2127,7 +2128,7 @@ def schema_for_type(
schema += " CASESENSITIVE"
elif typ is bool:
schema = f"{path} AS {index_field_name} TAG"
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
elif is_numeric_type(typ):
schema = f"{path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
if full_text_search is True:
Expand Down
26 changes: 26 additions & 0 deletions aredis_om/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import decimal
import inspect
from typing import Any, Type, get_args


def is_async_mode() -> bool:
Expand All @@ -10,3 +12,27 @@ async def f() -> None:


ASYNC_MODE = is_async_mode()

NUMERIC_TYPES = (float, int, decimal.Decimal)


def is_numeric_type(type_: Type[Any]) -> bool:
try:
return issubclass(type_, NUMERIC_TYPES)
except TypeError:
return False


def has_numeric_inner_type(type_: Type[Any]) -> bool:
"""
Check if the type has a numeric inner type.
"""
args = get_args(type_)

if not args:
return False

try:
return issubclass(args[0], NUMERIC_TYPES)
except TypeError:
return False
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redis-om"
version = "1.0.1-beta"
version = "1.0.2-beta"
description = "Object mappings, and more, for Redis."
authors = ["Redis OSS <[email protected]>"]
maintainers = ["Redis OSS <[email protected]>"]
Expand All @@ -22,6 +22,7 @@ classifiers = [
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: 3.13',
'Programming Language :: Python',
]
include=[
Expand Down
6 changes: 3 additions & 3 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ async def test_full_text_search_queries(members, m):
async def test_pagination_queries(members, m):
member1, member2, member3 = members

actual = await m.Member.find(m.Member.last_name == "Brookins").page()
actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("id").page()

assert actual == [member1, member2]

actual = await m.Member.find().page(1, 1)
actual = await m.Member.find().sort_by("id").page(1, 1)

assert actual == [member2]

actual = await m.Member.find().page(0, 1)
actual = await m.Member.find().sort_by("id").page(0, 1)

assert actual == [member1]

Expand Down
6 changes: 4 additions & 2 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,8 +755,10 @@ async def test_sorting(members, m):
async def test_case_sensitive(members, m):
member1, member2, member3 = members

actual = await m.Member.find(m.Member.first_name == "Andrew").all()
assert actual == [member1, member3]
actual = await m.Member.find(m.Member.first_name == "Andrew").sort_by("pk").all()
assert sorted(actual, key=lambda m: m.pk) == sorted(
[member1, member3], key=lambda m: m.pk
)

actual = await m.Member.find(m.Member.first_name == "andrew").all()
assert actual == []
Expand Down
45 changes: 43 additions & 2 deletions tests/test_knn_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,24 @@ class Meta:

class Member(BaseJsonModel, index=True):
name: str
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
embeddings: list[float] = Field([], vector_options=vector_field_options)
embeddings_score: Optional[float] = None

await Migrator().run()

return Member


@pytest_asyncio.fixture
async def n(key_prefix, redis):
class BaseJsonModel(JsonModel, abc.ABC):
class Meta:
global_key_prefix = key_prefix
database = redis

class Member(BaseJsonModel, index=True):
name: str
nested: list[list[float]] = Field([], vector_options=vector_field_options)
embeddings_score: Optional[float] = None

await Migrator().run()
Expand All @@ -45,7 +62,7 @@ def to_bytes(vectors: list[float]) -> bytes:
async def test_vector_field(m: Type[JsonModel]):
# Create a new instance of the Member model
vectors = [0.3 for _ in range(DIMENSIONS)]
member = m(name="seth", embeddings=[vectors])
member = m(name="seth", embeddings=vectors)

# Save the member to Redis
await member.save()
Expand All @@ -63,3 +80,27 @@ async def test_vector_field(m: Type[JsonModel]):

assert len(members) == 1
assert members[0].embeddings_score is not None


@py_test_mark_asyncio
async def test_nested_vector_field(n: Type[JsonModel]):
# Create a new instance of the Member model
vectors = [0.3 for _ in range(DIMENSIONS)]
member = n(name="seth", nested=[vectors])

# Save the member to Redis
await member.save()

knn = KNNExpression(
k=1,
vector_field=n.nested,
score_field=n.embeddings_score,
reference_vector=to_bytes(vectors),
)

query = n.find(knn=knn)

members = await query.all()

assert len(members) == 1
assert members[0].embeddings_score is not None