From 7b718ace9abbe8a3c31bbba4a348c61ce92e273b Mon Sep 17 00:00:00 2001 From: Anurag Wagh Date: Thu, 3 Aug 2023 10:13:31 +0530 Subject: [PATCH 1/5] [mod] update `get_redis_connection` to allow redis cluster connection --- aredis_om/connections.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/aredis_om/connections.py b/aredis_om/connections.py index a8e693e2..d01882b3 100644 --- a/aredis_om/connections.py +++ b/aredis_om/connections.py @@ -1,12 +1,12 @@ import os +from typing import Union from . import redis - URL = os.environ.get("REDIS_OM_URL", None) -def get_redis_connection(**kwargs) -> redis.Redis: +def get_redis_connection(**kwargs) -> Union[redis.Redis, redis.RedisCluster]: # Decode from UTF-8 by default if "decode_responses" not in kwargs: kwargs["decode_responses"] = True @@ -14,7 +14,9 @@ def get_redis_connection(**kwargs) -> redis.Redis: # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL # environment variable, we'll create the Redis client from the URL. 
url = kwargs.pop("url", URL) + cluster = kwargs.get("cluster", False) + conn_obj = redis.RedisCluster if cluster else redis.Redis if url: - return redis.Redis.from_url(url, **kwargs) + return conn_obj.from_url(url, **kwargs) - return redis.Redis(**kwargs) + return conn_obj(**kwargs) From b89719ac5f729a60ba4ced7d31fb39c333e56b42 Mon Sep 17 00:00:00 2001 From: Anurag Wagh Date: Thu, 3 Aug 2023 10:25:36 +0530 Subject: [PATCH 2/5] [mod] functionality to create indexes on Cluster --- aredis_om/model/migrations/migrator.py | 28 ++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/aredis_om/model/migrations/migrator.py b/aredis_om/model/migrations/migrator.py index 27cac65e..01d2dd51 100644 --- a/aredis_om/model/migrations/migrator.py +++ b/aredis_om/model/migrations/migrator.py @@ -2,14 +2,12 @@ import logging from dataclasses import dataclass from enum import Enum -from typing import List, Optional +from typing import List, Optional, Union from ... import redis - log = logging.getLogger(__name__) - import importlib # noqa: E402 import pkgutil # noqa: E402 @@ -30,7 +28,7 @@ def import_submodules(root_module_name: str): ) for loader, module_name, is_pkg in pkgutil.walk_packages( - root_module.__path__, root_module.__name__ + "." # type: ignore + root_module.__path__, root_module.__name__ + "." # type: ignore ): importlib.import_module(module_name) @@ -39,7 +37,25 @@ def schema_hash_key(index_name): return f"{index_name}:hash" -async def create_index(conn: redis.Redis, index_name, schema, current_hash): +async def _create_index_cluster(conn: redis.RedisCluster, index_name, schema, current_hash): + """Create a search index on a Redis Cluster. + This is a workaround for the fact that the `FT.CREATE` command is not supported in Redis Cluster. 
+ The implementation is the same as `create_index` but with the following changes: + - `command` is passed as a list instead of a string + - The `FT.CREATE` command is executed only on primary nodes + """ + try: + await conn.ft(index_name).info() + except redis.ResponseError: + command = f"ft.create {index_name} {schema}".split() + await conn.execute_command(*command, target_nodes="primaries") + await conn.set(schema_hash_key(index_name), current_hash) # type: ignore + + +async def create_index(conn: Union[redis.Redis, redis.RedisCluster], index_name, schema, current_hash): + if type(conn) is redis.RedisCluster: + return await _create_index_cluster(conn, index_name, schema, current_hash) + db_number = conn.connection_pool.connection_kwargs.get("db") if db_number and db_number > 0: raise MigrationError( @@ -68,7 +84,7 @@ class IndexMigration: schema: str hash: str action: MigrationAction - conn: redis.Redis + conn: Union[redis.Redis, redis.RedisCluster] previous_hash: Optional[str] = None async def run(self): From 31e06f6e7593839b2070f76e9a1fe264a55c8d2d Mon Sep 17 00:00:00 2001 From: Anurag Wagh Date: Thu, 3 Aug 2023 10:28:19 +0530 Subject: [PATCH 3/5] [mod] enhance type hints to support both Redis and RedisCluster --- aredis_om/model/model.py | 203 +++++++++++++++++++-------------------- 1 file changed, 101 insertions(+), 102 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index a4c6b9e7..6e096524 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -50,7 +50,6 @@ from .render_tree import render_tree from .token_escaper import TokenEscaper - model_registry = {} _T = TypeVar("_T") Model = TypeVar("Model", bound="RedisModel") @@ -145,7 +144,7 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any def decode_redis_value( - obj: Union[List[bytes], Dict[bytes, bytes], bytes], encoding: str + obj: Union[List[bytes], Dict[bytes, bytes], bytes], encoding: str ) -> Union[List[str], Dict[str, str], str]: 
"""Decode a binary-encoded Redis hash into the specified encoding.""" if isinstance(obj, list): @@ -162,7 +161,7 @@ def decode_redis_value( def remove_prefix(value: str, prefix: str) -> str: """Remove a prefix from a string.""" if value.startswith(prefix): - value = value[len(prefix) :] # noqa: E203 + value = value[len(prefix):] # noqa: E203 return value @@ -171,7 +170,7 @@ class PipelineError(Exception): def verify_pipeline_response( - response: List[Union[bytes, str]], expected_responses: int = 0 + response: List[Union[bytes, str]], expected_responses: int = 0 ): # TODO: More generic pipeline verification here (what else is possible?), # plus hash and JSON-specific verifications in separate functions. @@ -371,15 +370,15 @@ class RediSearchFieldTypes(Enum): class FindQuery: def __init__( - self, - expressions: Sequence[ExpressionOrNegated], - model: Type["RedisModel"], - knn: Optional[KNNExpression] = None, - offset: int = 0, - limit: Optional[int] = None, - page_size: int = DEFAULT_PAGE_SIZE, - sort_fields: Optional[List[str]] = None, - nocontent: bool = False, + self, + expressions: Sequence[ExpressionOrNegated], + model: Type["RedisModel"], + knn: Optional[KNNExpression] = None, + offset: int = 0, + limit: Optional[int] = None, + page_size: int = DEFAULT_PAGE_SIZE, + sort_fields: Optional[List[str]] = None, + nocontent: bool = False, ): if not has_redisearch(model.db()): raise RedisModelError( @@ -456,10 +455,10 @@ def query(self): self._query = self.resolve_redisearch_query(self.expression) if self.knn: self._query = ( - self._query - if self._query.startswith("(") or self._query == "*" - else f"({self._query})" - ) + f"=>[{self.knn}]" + self._query + if self._query.startswith("(") or self._query == "*" + else f"({self._query})" + ) + f"=>[{self.knn}]" return self._query @property @@ -560,13 +559,13 @@ def expand_tag_value(value): @classmethod def resolve_value( - cls, - field_name: str, - field_type: RediSearchFieldTypes, - field_info: PydanticFieldInfo, - 
op: Operators, - value: Any, - parents: List[Tuple[str, "RedisModel"]], + cls, + field_name: str, + field_type: RediSearchFieldTypes, + field_info: PydanticFieldInfo, + op: Operators, + value: Any, + parents: List[Tuple[str, "RedisModel"]], ) -> str: if parents: prefix = "_".join([p[0] for p in parents]) @@ -707,7 +706,7 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: return "*" if isinstance(expression.left, Expression) or isinstance( - expression.left, NegatedExpression + expression.left, NegatedExpression ): result += f"({cls.resolve_redisearch_query(expression.left)})" elif isinstance(expression.left, ModelField): @@ -963,11 +962,11 @@ def create_pk(*args, **kwargs) -> str: def __dataclass_transform__( - *, - eq_default: bool = True, - order_default: bool = False, - kw_only_default: bool = False, - field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()), + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] 
= (()), ) -> Callable[[_T], _T]: return lambda a: a @@ -989,10 +988,10 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: class RelationshipInfo(Representation): def __init__( - self, - *, - back_populates: Optional[str] = None, - link_model: Optional[Any] = None, + self, + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, ) -> None: self.back_populates = back_populates self.link_model = link_model @@ -1032,11 +1031,11 @@ class DISTANCE_METRIC(Enum): @staticmethod def flat( - type: TYPE, - dimension: int, - distance_metric: DISTANCE_METRIC, - initial_cap: Optional[int] = None, - block_size: Optional[int] = None, + type: TYPE, + dimension: int, + distance_metric: DISTANCE_METRIC, + initial_cap: Optional[int] = None, + block_size: Optional[int] = None, ): return VectorFieldOptions( algorithm=VectorFieldOptions.ALGORITHM.FLAT, @@ -1049,14 +1048,14 @@ def flat( @staticmethod def hnsw( - type: TYPE, - dimension: int, - distance_metric: DISTANCE_METRIC, - initial_cap: Optional[int] = None, - m: Optional[int] = None, - ef_construction: Optional[int] = None, - ef_runtime: Optional[int] = None, - epsilon: Optional[float] = None, + type: TYPE, + dimension: int, + distance_metric: DISTANCE_METRIC, + initial_cap: Optional[int] = None, + m: Optional[int] = None, + ef_construction: Optional[int] = None, + ef_runtime: Optional[int] = None, + epsilon: Optional[float] = None, ): return VectorFieldOptions( algorithm=VectorFieldOptions.ALGORITHM.HNSW, @@ -1087,36 +1086,36 @@ def schema(self): def Field( - default: Any = Undefined, - *, - default_factory: Optional[NoArgAnyCallable] = None, - alias: str = None, - title: str = None, - description: str = None, - exclude: Union[ - AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - include: Union[ - AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - const: bool = None, - gt: float = None, - ge: float = None, - lt: float = None, - le: 
float = None, - multiple_of: float = None, - min_items: int = None, - max_items: int = None, - min_length: int = None, - max_length: int = None, - allow_mutation: bool = True, - regex: str = None, - primary_key: bool = False, - sortable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - full_text_search: Union[bool, UndefinedType] = Undefined, - vector_options: Optional[VectorFieldOptions] = None, - schema_extra: Optional[Dict[str, Any]] = None, + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: str = None, + title: str = None, + description: str = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: bool = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + multiple_of: float = None, + min_items: int = None, + max_items: int = None, + min_length: int = None, + max_length: int = None, + allow_mutation: bool = True, + regex: str = None, + primary_key: bool = False, + sortable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + full_text_search: Union[bool, UndefinedType] = Undefined, + vector_options: Optional[VectorFieldOptions] = None, + schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} field_info = FieldInfo( @@ -1160,7 +1159,7 @@ class BaseMeta(Protocol): global_key_prefix: str model_key_prefix: str primary_key_pattern: str - database: redis.Redis + database: Union[redis.Redis, redis.RedisCluster] primary_key: PrimaryKey primary_key_creator_cls: Type[PrimaryKeyCreator] index_name: str @@ -1179,7 +1178,7 @@ class DefaultMeta: global_key_prefix: Optional[str] = None model_key_prefix: Optional[str] = None primary_key_pattern: Optional[str] = None - database: Optional[redis.Redis] = None + database: 
Optional[Union[redis.Redis, redis.RedisCluster]] = None primary_key: Optional[PrimaryKey] = None primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None index_name: Optional[str] = None @@ -1309,7 +1308,7 @@ async def _delete(cls, db, *pks): @classmethod async def delete( - cls, pk: Any, pipeline: Optional[redis.client.Pipeline] = None + cls, pk: Any, pipeline: Optional[redis.client.Pipeline] = None ) -> int: """Delete data at this key.""" db = cls._get_db(pipeline) @@ -1325,12 +1324,12 @@ async def update(self, **field_values): raise NotImplementedError async def save( - self: "Model", pipeline: Optional[redis.client.Pipeline] = None + self: "Model", pipeline: Optional[redis.client.Pipeline] = None ) -> "Model": raise NotImplementedError async def expire( - self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None + self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None ): db = self._get_db(pipeline) @@ -1374,9 +1373,9 @@ def db(cls): @classmethod def find( - cls, - *expressions: Union[Any, Expression], - knn: Optional[KNNExpression] = None, + cls, + *expressions: Union[Any, Expression], + knn: Optional[KNNExpression] = None, ) -> FindQuery: return FindQuery(expressions=expressions, knn=knn, model=cls) @@ -1430,10 +1429,10 @@ def get_annotations(cls): @classmethod async def add( - cls: Type["Model"], - models: Sequence["Model"], - pipeline: Optional[redis.client.Pipeline] = None, - pipeline_verifier: Callable[..., Any] = verify_pipeline_response, + cls: Type["Model"], + models: Sequence["Model"], + pipeline: Optional[redis.client.Pipeline] = None, + pipeline_verifier: Callable[..., Any] = verify_pipeline_response, ) -> Sequence["Model"]: db = cls._get_db(pipeline, bulk=True) @@ -1451,7 +1450,7 @@ async def add( @classmethod def _get_db( - self, pipeline: Optional[redis.client.Pipeline] = None, bulk: bool = False + self, pipeline: Optional[redis.client.Pipeline] = None, bulk: bool = False ): if pipeline is not None: return 
pipeline @@ -1462,9 +1461,9 @@ def _get_db( @classmethod async def delete_many( - cls, - models: Sequence["RedisModel"], - pipeline: Optional[redis.client.Pipeline] = None, + cls, + models: Sequence["RedisModel"], + pipeline: Optional[redis.client.Pipeline] = None, ) -> int: db = cls._get_db(pipeline) @@ -1509,7 +1508,7 @@ def __init_subclass__(cls, **kwargs): ) async def save( - self: "Model", pipeline: Optional[redis.client.Pipeline] = None + self: "Model", pipeline: Optional[redis.client.Pipeline] = None ) -> "Model": self.check() db = self._get_db(pipeline) @@ -1683,7 +1682,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def save( - self: "Model", pipeline: Optional[redis.client.Pipeline] = None + self: "Model", pipeline: Optional[redis.client.Pipeline] = None ) -> "Model": self.check() db = self._get_db(pipeline) @@ -1756,13 +1755,13 @@ def schema_for_fields(cls): @classmethod def schema_for_type( - cls, - json_path: str, - name: str, - name_prefix: str, - typ: Any, - field_info: PydanticFieldInfo, - parent_type: Optional[Any] = None, + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Any, + field_info: PydanticFieldInfo, + parent_type: Optional[Any] = None, ) -> str: should_index = getattr(field_info, "index", False) is_container_type = is_supported_container_type(typ) From ed1525e2a5800ecfc390d8fc8e62f239648140aa Mon Sep 17 00:00:00 2001 From: Anurag Wagh Date: Thu, 3 Aug 2023 10:44:10 +0530 Subject: [PATCH 4/5] [refactor] use predefined flag to indicate primary clusters --- aredis_om/model/migrations/migrator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aredis_om/model/migrations/migrator.py b/aredis_om/model/migrations/migrator.py index 01d2dd51..8cdcc3df 100644 --- a/aredis_om/model/migrations/migrator.py +++ b/aredis_om/model/migrations/migrator.py @@ -48,7 +48,8 @@ async def _create_index_cluster(conn: redis.RedisCluster, index_name, schema, cu await conn.ft(index_name).info() 
except redis.ResponseError: command = f"ft.create {index_name} {schema}".split() - await conn.execute_command(*command, target_nodes="primaries") + + await conn.execute_command(*command, target_nodes=redis.RedisCluster.PRIMARIES) await conn.set(schema_hash_key(index_name), current_hash) # type: ignore From d911368757b408452ca6e150c7fef56e3a12cc56 Mon Sep 17 00:00:00 2001 From: Anurag Wagh Date: Thu, 3 Aug 2023 11:20:33 +0530 Subject: [PATCH 5/5] [mod] check if url contains cluster=true --- aredis_om/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aredis_om/connections.py b/aredis_om/connections.py index d01882b3..d637768e 100644 --- a/aredis_om/connections.py +++ b/aredis_om/connections.py @@ -14,7 +14,7 @@ def get_redis_connection(**kwargs) -> Union[redis.Redis, redis.RedisCluster]: # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL # environment variable, we'll create the Redis client from the URL. url = kwargs.pop("url", URL) - cluster = kwargs.get("cluster", False) + cluster = kwargs.pop("cluster", False) or "cluster=true" in str(url).lower() conn_obj = redis.RedisCluster if cluster else redis.Redis if url: return conn_obj.from_url(url, **kwargs)