diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f31a11..56d094f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ] + pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ] redisstack: [ "latest" ] fail-fast: false services: diff --git a/.gitignore b/.gitignore index 8f21f6a..5e27823 100644 --- a/.gitignore +++ b/.gitignore @@ -143,4 +143,7 @@ tests_sync/ # spelling cruft *.dic -.idea \ No newline at end of file +.idea + +# version files +.tool-versions diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 5444991..5a5c75e 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -41,7 +41,7 @@ from .. import redis from ..checks import has_redis_json, has_redisearch from ..connections import get_redis_connection -from ..util import ASYNC_MODE +from ..util import ASYNC_MODE, has_numeric_inner_type, is_numeric_type from .encoders import jsonable_encoder from .render_tree import render_tree from .token_escaper import TokenEscaper @@ -406,7 +406,6 @@ class RediSearchFieldTypes(Enum): # TODO: How to handle Geo fields? -NUMERIC_TYPES = (float, int, decimal.Decimal) DEFAULT_PAGE_SIZE = 1000 @@ -578,7 +577,7 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType ) elif field_type is bool: return RediSearchFieldTypes.TAG - elif any(issubclass(field_type, t) for t in NUMERIC_TYPES): + elif is_numeric_type(field_type): # Index numeric Python types as NUMERIC fields, so we can support # range queries. return RediSearchFieldTypes.NUMERIC @@ -1805,7 +1804,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = cls.schema_for_type(name, embedded_cls, field_info) elif typ is bool: schema = f"{name} TAG" - elif any(issubclass(typ, t) for t in NUMERIC_TYPES): + elif is_numeric_type(typ): vector_options: Optional[VectorFieldOptions] = getattr( field_info, "vector_options", None ) @@ -2004,9 +2003,7 @@ def schema_for_type( field_info, "vector_options", None ) try: - is_vector = vector_options and any( - issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES - ) + is_vector = vector_options and has_numeric_inner_type(typ) except IndexError: raise RedisModelError( f"Vector field '{name}' must be annotated as a container type" @@ -2104,7 +2101,11 @@ def schema_for_type( # a proper type, we can pull the type information from the origin of the first argument. if not isinstance(typ, type): type_args = typing_get_args(field_info.annotation) - typ = type_args[0].__origin__ + typ = ( + getattr(type_args[0], "__origin__", type_args[0]) + if type_args + else typ + ) # TODO: GEO field if is_vector and vector_options: @@ -2127,7 +2128,7 @@ def schema_for_type( schema += " CASESENSITIVE" elif typ is bool: schema = f"{path} AS {index_field_name} TAG" - elif any(issubclass(typ, t) for t in NUMERIC_TYPES): + elif is_numeric_type(typ): schema = f"{path} AS {index_field_name} NUMERIC" elif issubclass(typ, str): if full_text_search is True: diff --git a/aredis_om/util.py b/aredis_om/util.py index 268657e..fc6a534 100644 --- a/aredis_om/util.py +++ b/aredis_om/util.py @@ -1,4 +1,6 @@ +import decimal import inspect +from typing import Any, Type, get_args def is_async_mode() -> bool: @@ -10,3 +12,27 @@ async def f() -> None: ASYNC_MODE = is_async_mode() + +NUMERIC_TYPES = (float, int, decimal.Decimal) + + +def is_numeric_type(type_: Type[Any]) -> bool: + try: + return issubclass(type_, NUMERIC_TYPES) + except TypeError: + return False + + +def has_numeric_inner_type(type_: Type[Any]) -> bool: + """ + Check if the type has a numeric inner type. + """ + args = get_args(type_) + + if not args: + return False + + try: + return issubclass(args[0], NUMERIC_TYPES) + except TypeError: + return False diff --git a/pyproject.toml b/pyproject.toml index f6b14b3..a672735 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-om" -version = "1.0.1-beta" +version = "1.0.2-beta" description = "Object mappings, and more, for Redis." authors = ["Redis OSS "] maintainers = ["Redis OSS "] @@ -22,6 +22,7 @@ classifiers = [ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Programming Language :: Python', ] include=[ diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 99bc36b..c3b578a 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -180,15 +180,15 @@ async def test_full_text_search_queries(members, m): async def test_pagination_queries(members, m): member1, member2, member3 = members - actual = await m.Member.find(m.Member.last_name == "Brookins").page() + actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("id").page() assert actual == [member1, member2] - actual = await m.Member.find().page(1, 1) + actual = await m.Member.find().sort_by("id").page(1, 1) assert actual == [member2] - actual = await m.Member.find().page(0, 1) + actual = await m.Member.find().sort_by("id").page(0, 1) assert actual == [member1] diff --git a/tests/test_json_model.py b/tests/test_json_model.py index d6428e1..44ae9c6 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -755,8 +755,10 @@ async def test_sorting(members, m): async def test_case_sensitive(members, m): member1, member2, member3 = members - actual = await m.Member.find(m.Member.first_name == "Andrew").all() - assert actual == [member1, member3] + actual = await m.Member.find(m.Member.first_name == "Andrew").sort_by("pk").all() + assert sorted(actual, key=lambda m: m.pk) == sorted( + [member1, member3], key=lambda m: m.pk + ) actual = await m.Member.find(m.Member.first_name == "andrew").all() assert actual == [] diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py index cea2e76..258e102 100644 --- a/tests/test_knn_expression.py +++ b/tests/test_knn_expression.py @@ -29,7 +29,24 @@ class Meta: class Member(BaseJsonModel, index=True): name: str - embeddings: list[list[float]] = Field([], vector_options=vector_field_options) + embeddings: list[float] = Field([], vector_options=vector_field_options) + embeddings_score: Optional[float] = None + + await Migrator().run() + + return Member + + +@pytest_asyncio.fixture +async def n(key_prefix, redis): + class BaseJsonModel(JsonModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + database = redis + + class Member(BaseJsonModel, index=True): + name: str + nested: list[list[float]] = Field([], vector_options=vector_field_options) embeddings_score: Optional[float] = None await Migrator().run() @@ -45,7 +62,7 @@ def to_bytes(vectors: list[float]) -> bytes: async def test_vector_field(m: Type[JsonModel]): # Create a new instance of the Member model vectors = [0.3 for _ in range(DIMENSIONS)] - member = m(name="seth", embeddings=[vectors]) + member = m(name="seth", embeddings=vectors) # Save the member to Redis await member.save() @@ -63,3 +80,27 @@ async def test_vector_field(m: Type[JsonModel]): assert len(members) == 1 assert members[0].embeddings_score is not None + + +@py_test_mark_asyncio +async def test_nested_vector_field(n: Type[JsonModel]): + # Create a new instance of the Member model + vectors = [0.3 for _ in range(DIMENSIONS)] + member = n(name="seth", nested=[vectors]) + + # Save the member to Redis + await member.save() + + knn = KNNExpression( + k=1, + vector_field=n.nested, + score_field=n.embeddings_score, + reference_vector=to_bytes(vectors), + ) + + query = n.find(knn=knn) + + members = await query.all() + + assert len(members) == 1 + assert members[0].embeddings_score is not None