Skip to content

Commit 8ddb6b7

Browse files
authored
fix issue with nested vector fields and python 3.13 issubclass changes (#699)
* fix issue with nested vector fields and python 3.13 issubclass changes * update version * update numeric checks to be utility functions * remove vscode settings * fix flaky test * fix flaky test
1 parent 7ca997c commit 8ddb6b7

File tree

8 files changed

+93
-19
lines changed

8 files changed

+93
-19
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
strategy:
7777
matrix:
7878
os: [ ubuntu-latest ]
79-
pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ]
79+
pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ]
8080
redisstack: [ "latest" ]
8181
fail-fast: false
8282
services:

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,7 @@ tests_sync/
143143
# spelling cruft
144144
*.dic
145145

146-
.idea
146+
.idea
147+
148+
# version files
149+
.tool-versions

aredis_om/model/model.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from .. import redis
4242
from ..checks import has_redis_json, has_redisearch
4343
from ..connections import get_redis_connection
44-
from ..util import ASYNC_MODE
44+
from ..util import ASYNC_MODE, has_numeric_inner_type, is_numeric_type
4545
from .encoders import jsonable_encoder
4646
from .render_tree import render_tree
4747
from .token_escaper import TokenEscaper
@@ -406,7 +406,6 @@ class RediSearchFieldTypes(Enum):
406406

407407

408408
# TODO: How to handle Geo fields?
409-
NUMERIC_TYPES = (float, int, decimal.Decimal)
410409
DEFAULT_PAGE_SIZE = 1000
411410

412411

@@ -578,7 +577,7 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
578577
)
579578
elif field_type is bool:
580579
return RediSearchFieldTypes.TAG
581-
elif any(issubclass(field_type, t) for t in NUMERIC_TYPES):
580+
elif is_numeric_type(field_type):
582581
# Index numeric Python types as NUMERIC fields, so we can support
583582
# range queries.
584583
return RediSearchFieldTypes.NUMERIC
@@ -1805,7 +1804,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
18051804
schema = cls.schema_for_type(name, embedded_cls, field_info)
18061805
elif typ is bool:
18071806
schema = f"{name} TAG"
1808-
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
1807+
elif is_numeric_type(typ):
18091808
vector_options: Optional[VectorFieldOptions] = getattr(
18101809
field_info, "vector_options", None
18111810
)
@@ -2004,9 +2003,7 @@ def schema_for_type(
20042003
field_info, "vector_options", None
20052004
)
20062005
try:
2007-
is_vector = vector_options and any(
2008-
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
2009-
)
2006+
is_vector = vector_options and has_numeric_inner_type(typ)
20102007
except IndexError:
20112008
raise RedisModelError(
20122009
f"Vector field '{name}' must be annotated as a container type"
@@ -2104,7 +2101,11 @@ def schema_for_type(
21042101
# a proper type, we can pull the type information from the origin of the first argument.
21052102
if not isinstance(typ, type):
21062103
type_args = typing_get_args(field_info.annotation)
2107-
typ = type_args[0].__origin__
2104+
typ = (
2105+
getattr(type_args[0], "__origin__", type_args[0])
2106+
if type_args
2107+
else typ
2108+
)
21082109

21092110
# TODO: GEO field
21102111
if is_vector and vector_options:
@@ -2127,7 +2128,7 @@ def schema_for_type(
21272128
schema += " CASESENSITIVE"
21282129
elif typ is bool:
21292130
schema = f"{path} AS {index_field_name} TAG"
2130-
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
2131+
elif is_numeric_type(typ):
21312132
schema = f"{path} AS {index_field_name} NUMERIC"
21322133
elif issubclass(typ, str):
21332134
if full_text_search is True:

aredis_om/util.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import decimal
12
import inspect
3+
from typing import Any, Type, get_args
24

35

46
def is_async_mode() -> bool:
@@ -10,3 +12,27 @@ async def f() -> None:
1012

1113

1214
ASYNC_MODE = is_async_mode()
15+
16+
NUMERIC_TYPES = (float, int, decimal.Decimal)
17+
18+
19+
def is_numeric_type(type_: Type[Any]) -> bool:
20+
try:
21+
return issubclass(type_, NUMERIC_TYPES)
22+
except TypeError:
23+
return False
24+
25+
26+
def has_numeric_inner_type(type_: Type[Any]) -> bool:
27+
"""
28+
Check if the type has a numeric inner type.
29+
"""
30+
args = get_args(type_)
31+
32+
if not args:
33+
return False
34+
35+
try:
36+
return issubclass(args[0], NUMERIC_TYPES)
37+
except TypeError:
38+
return False

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "redis-om"
3-
version = "1.0.1-beta"
3+
version = "1.0.2-beta"
44
description = "Object mappings, and more, for Redis."
55
authors = ["Redis OSS <[email protected]>"]
66
maintainers = ["Redis OSS <[email protected]>"]
@@ -22,6 +22,7 @@ classifiers = [
2222
'Programming Language :: Python :: 3.10',
2323
'Programming Language :: Python :: 3.11',
2424
'Programming Language :: Python :: 3.12',
25+
'Programming Language :: Python :: 3.13',
2526
'Programming Language :: Python',
2627
]
2728
include=[

tests/test_hash_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,15 @@ async def test_full_text_search_queries(members, m):
180180
async def test_pagination_queries(members, m):
181181
member1, member2, member3 = members
182182

183-
actual = await m.Member.find(m.Member.last_name == "Brookins").page()
183+
actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("id").page()
184184

185185
assert actual == [member1, member2]
186186

187-
actual = await m.Member.find().page(1, 1)
187+
actual = await m.Member.find().sort_by("id").page(1, 1)
188188

189189
assert actual == [member2]
190190

191-
actual = await m.Member.find().page(0, 1)
191+
actual = await m.Member.find().sort_by("id").page(0, 1)
192192

193193
assert actual == [member1]
194194

tests/test_json_model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -755,8 +755,10 @@ async def test_sorting(members, m):
755755
async def test_case_sensitive(members, m):
756756
member1, member2, member3 = members
757757

758-
actual = await m.Member.find(m.Member.first_name == "Andrew").all()
759-
assert actual == [member1, member3]
758+
actual = await m.Member.find(m.Member.first_name == "Andrew").sort_by("pk").all()
759+
assert sorted(actual, key=lambda m: m.pk) == sorted(
760+
[member1, member3], key=lambda m: m.pk
761+
)
760762

761763
actual = await m.Member.find(m.Member.first_name == "andrew").all()
762764
assert actual == []

tests/test_knn_expression.py

+43-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,24 @@ class Meta:
2929

3030
class Member(BaseJsonModel, index=True):
3131
name: str
32-
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
32+
embeddings: list[float] = Field([], vector_options=vector_field_options)
33+
embeddings_score: Optional[float] = None
34+
35+
await Migrator().run()
36+
37+
return Member
38+
39+
40+
@pytest_asyncio.fixture
41+
async def n(key_prefix, redis):
42+
class BaseJsonModel(JsonModel, abc.ABC):
43+
class Meta:
44+
global_key_prefix = key_prefix
45+
database = redis
46+
47+
class Member(BaseJsonModel, index=True):
48+
name: str
49+
nested: list[list[float]] = Field([], vector_options=vector_field_options)
3350
embeddings_score: Optional[float] = None
3451

3552
await Migrator().run()
@@ -45,7 +62,7 @@ def to_bytes(vectors: list[float]) -> bytes:
4562
async def test_vector_field(m: Type[JsonModel]):
4663
# Create a new instance of the Member model
4764
vectors = [0.3 for _ in range(DIMENSIONS)]
48-
member = m(name="seth", embeddings=[vectors])
65+
member = m(name="seth", embeddings=vectors)
4966

5067
# Save the member to Redis
5168
await member.save()
@@ -63,3 +80,27 @@ async def test_vector_field(m: Type[JsonModel]):
6380

6481
assert len(members) == 1
6582
assert members[0].embeddings_score is not None
83+
84+
85+
@py_test_mark_asyncio
86+
async def test_nested_vector_field(n: Type[JsonModel]):
87+
# Create a new instance of the Member model
88+
vectors = [0.3 for _ in range(DIMENSIONS)]
89+
member = n(name="seth", nested=[vectors])
90+
91+
# Save the member to Redis
92+
await member.save()
93+
94+
knn = KNNExpression(
95+
k=1,
96+
vector_field=n.nested,
97+
score_field=n.embeddings_score,
98+
reference_vector=to_bytes(vectors),
99+
)
100+
101+
query = n.find(knn=knn)
102+
103+
members = await query.all()
104+
105+
assert len(members) == 1
106+
assert members[0].embeddings_score is not None

0 commit comments

Comments
 (0)