Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions redis/commands/search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import redis
from redis.client import Pipeline as RedisPipeline

from ...asyncio.client import Pipeline as AsyncioPipeline
from .commands import (
Expand Down Expand Up @@ -181,9 +181,17 @@ def pipeline(self, transaction=True, shard_hint=None):
return p


class Pipeline(SearchCommands, redis.client.Pipeline):
class Pipeline(SearchCommands, RedisPipeline):
"""Pipeline for the module."""

def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
super().__init__(connection_pool, response_callbacks, transaction, shard_hint)
self.index_name: str = ""

class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):

class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline):
"""AsyncPipeline for the module."""

def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
super().__init__(connection_pool, response_callbacks, transaction, shard_hint)
self.index_name: str = ""
50 changes: 24 additions & 26 deletions redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Optional, Tuple, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

Expand Down Expand Up @@ -27,9 +27,9 @@ class Reducer:
NAME = None

def __init__(self, *args: str) -> None:
self._args = args
self._field = None
self._alias = None
self._args: Tuple[str, ...] = args
self._field: Optional[str] = None
self._alias: Optional[str] = None

def alias(self, alias: str) -> "Reducer":
"""
Expand All @@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer":
if alias is FIELDNAME:
if not self._field:
raise ValueError("Cannot use FIELDNAME alias with no field")
# Chop off initial '@'
alias = self._field[1:]
else:
# Chop off initial '@'
alias = self._field[1:]
self._alias = alias
return self

@property
def args(self) -> List[str]:
def args(self) -> Tuple[str, ...]:
return self._args


Expand All @@ -64,7 +65,7 @@ class SortDirection:
This special class is used to indicate sort direction.
"""

DIRSTRING = None
DIRSTRING: Optional[str] = None

def __init__(self, field: str) -> None:
self.field = field
Expand Down Expand Up @@ -104,17 +105,17 @@ def __init__(self, query: str = "*") -> None:
All member methods (except `build_args()`)
return the object itself, making them useful for chaining.
"""
self._query = query
self._aggregateplan = []
self._loadfields = []
self._loadall = False
self._max = 0
self._with_schema = False
self._verbatim = False
self._cursor = []
self._dialect = DEFAULT_DIALECT
self._add_scores = False
self._scorer = "TFIDF"
self._query: str = query
self._aggregateplan: List[str] = []
self._loadfields: List[str] = []
self._loadall: bool = False
self._max: int = 0
self._with_schema: bool = False
self._verbatim: bool = False
self._cursor: List[str] = []
self._dialect: int = DEFAULT_DIALECT
self._add_scores: bool = False
self._scorer: str = "TFIDF"

def load(self, *fields: str) -> "AggregateRequest":
"""
Expand All @@ -133,7 +134,7 @@ def load(self, *fields: str) -> "AggregateRequest":
return self

def group_by(
self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
self, fields: Union[str, List[str]], *reducers: Reducer
) -> "AggregateRequest":
"""
Specify by which fields to group the aggregation.
Expand All @@ -147,7 +148,6 @@ def group_by(
`aggregation` module.
"""
fields = [fields] if isinstance(fields, str) else fields
reducers = [reducers] if isinstance(reducers, Reducer) else reducers

ret = ["GROUPBY", str(len(fields)), *fields]
for reducer in reducers:
Expand Down Expand Up @@ -251,12 +251,10 @@ def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
.sort_by(Desc("@paid"), max=10)
```
"""
if isinstance(fields, (str, SortDirection)):
fields = [fields]

fields_args = []
for f in fields:
if isinstance(f, SortDirection):
if isinstance(f, (Asc, Desc)):
fields_args += [f.field, f.DIRSTRING]
else:
fields_args += [f]
Expand Down Expand Up @@ -356,7 +354,7 @@ def build_args(self) -> List[str]:
ret.extend(self._loadfields)

if self._dialect:
ret.extend(["DIALECT", self._dialect])
ret.extend(["DIALECT", str(self._dialect)])

ret.extend(self._aggregateplan)

Expand Down Expand Up @@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None:
self.cursor = cursor
self.schema = schema

def __repr__(self) -> (str, str):
def __repr__(self) -> str:
cid = self.cursor.cid if self.cursor else -1
return (
f"<{self.__class__.__name__} at 0x{id(self):x} "
Expand Down
58 changes: 48 additions & 10 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,44 @@
class SearchCommands:
"""Search commands."""

@property
def index_name(self) -> str:
"""The name of the search index. Must be implemented by inheriting classes."""
if not hasattr(self, "_index_name"):
raise AttributeError("index_name must be set by the inheriting class")
return self._index_name

@index_name.setter
def index_name(self, value: str) -> None:
"""Set the name of the search index."""
self._index_name = value

@property
def client(self):
"""The Redis client. Must be provided by inheriting classes."""
if not hasattr(self, "_client"):
raise AttributeError("client must be set by the inheriting class")
return self._client

@client.setter
def client(self, value) -> None:
"""Set the Redis client."""
self._client = value

@property
def _RESP2_MODULE_CALLBACKS(self):
"""Response callbacks for RESP2. Must be provided by inheriting classes."""
if not hasattr(self, "_resp2_module_callbacks"):
raise AttributeError(
"_RESP2_MODULE_CALLBACKS must be set by the inheriting class"
)
return self._resp2_module_callbacks

@_RESP2_MODULE_CALLBACKS.setter
def _RESP2_MODULE_CALLBACKS(self, value) -> None:
"""Set the RESP2 module callbacks."""
self._resp2_module_callbacks = value

def _parse_results(self, cmd, res, **kwargs):
if get_protocol_version(self.client) in ["3", 3]:
return ProfileInformation(res) if cmd == "FT.PROFILE" else res
Expand Down Expand Up @@ -221,7 +259,7 @@ def create_index(

return self.execute_command(*args)

def alter_schema_add(self, fields: List[str]):
def alter_schema_add(self, fields: Union[Field, List[Field]]):
"""
Alter the existing search index by adding new fields. The index
must already exist.
Expand Down Expand Up @@ -336,11 +374,11 @@ def add_document(
doc_id: str,
nosave: bool = False,
score: float = 1.0,
payload: bool = None,
payload: Optional[bool] = None,
replace: bool = False,
partial: bool = False,
language: Optional[str] = None,
no_create: str = False,
no_create: bool = False,
**fields: List[str],
):
"""
Expand Down Expand Up @@ -464,7 +502,7 @@ def info(self):
return self._parse_results(INFO_CMD, res)

def get_params_args(
self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
):
if query_params is None:
return []
Expand All @@ -478,7 +516,7 @@ def get_params_args(
return args

def _mk_query_args(
self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
self, query, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
):
args = [self.index_name]

Expand Down Expand Up @@ -528,7 +566,7 @@ def search(
def explain(
self,
query: Union[str, Query],
query_params: Dict[str, Union[str, int, float]] = None,
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
):
"""Returns the execution plan for a complex query.

Expand All @@ -543,7 +581,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa
def aggregate(
self,
query: Union[AggregateRequest, Cursor],
query_params: Dict[str, Union[str, int, float]] = None,
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
):
"""
Issue an aggregation query.
Expand Down Expand Up @@ -598,7 +636,7 @@ def profile(
self,
query: Union[Query, AggregateRequest],
limited: bool = False,
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
):
"""
Performs a search or aggregate command and collects performance
Expand Down Expand Up @@ -936,7 +974,7 @@ async def info(self):
async def search(
self,
query: Union[str, Query],
query_params: Dict[str, Union[str, int, float]] = None,
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
):
"""
Search the index for a given query, and return a result of documents
Expand Down Expand Up @@ -968,7 +1006,7 @@ async def search(
async def aggregate(
self,
query: Union[AggregateResult, Cursor],
query_params: Dict[str, Union[str, int, float]] = None,
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
):
"""
Issue an aggregation query.
Expand Down
24 changes: 12 additions & 12 deletions redis/commands/search/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None:
self._with_scores: bool = False
self._scorer: Optional[str] = None
self._filters: List = list()
self._ids: Optional[List[str]] = None
self._ids: Optional[Tuple[str, ...]] = None
self._slop: int = -1
self._timeout: Optional[float] = None
self._in_order: bool = False
Expand Down Expand Up @@ -81,7 +81,7 @@ def return_field(
self._return_fields += ("AS", as_field)
return self

def _mk_field_list(self, fields: List[str]) -> List:
def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List:
if not fields:
return []
return [fields] if isinstance(fields, str) else list(fields)
Expand Down Expand Up @@ -126,7 +126,7 @@ def summarize(

def highlight(
self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
) -> None:
) -> "Query":
"""
Apply specified markup to matched term(s) within the returned field(s).

Expand Down Expand Up @@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query":
self._scorer = scorer
return self

def get_args(self) -> List[str]:
def get_args(self) -> List[Union[str, int, float]]:
"""Format the redis arguments for this query and return them."""
args = [self._query_string]
args: List[Union[str, int, float]] = [self._query_string]
args += self._get_args_tags()
args += self._summarize_fields + self._highlight_fields
args += ["LIMIT", self._offset, self._num]
return args

def _get_args_tags(self) -> List[str]:
args = []
def _get_args_tags(self) -> List[Union[str, int, float]]:
args: List[Union[str, int, float]] = []
if self._no_content:
args.append("NOCONTENT")
if self._fields:
Expand Down Expand Up @@ -288,14 +288,14 @@ def with_scores(self) -> "Query":
self._with_scores = True
return self

def limit_fields(self, *fields: List[str]) -> "Query":
def limit_fields(self, *fields: str) -> "Query":
"""
Limit the search to specific TEXT fields only.

- **fields**: A list of strings; case-sensitive field names
- **fields**: Each element should be a string, case sensitive field name
from the defined schema.
"""
self._fields = fields
self._fields = list(fields)
return self

def add_filter(self, flt: "Filter") -> "Query":
Expand Down Expand Up @@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query":


class Filter:
def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None:
self.args = [keyword, field] + list(args)


Expand Down
Loading