From 76b3bc4756c177804a5e63b47fa5e896bdb02d0d Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Sun, 3 Oct 2021 23:01:39 +0300 Subject: [PATCH 1/5] query params --- redisearch/query.py | 29 ++++++++++++++++ test/test.py | 85 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/redisearch/query.py b/redisearch/query.py index f741a6d..75c1c94 100644 --- a/redisearch/query.py +++ b/redisearch/query.py @@ -1,3 +1,4 @@ +from typing import Dict, Union import six class Query(object): @@ -34,6 +35,7 @@ def __init__(self, query_string): self._summarize_fields = [] self._highlight_fields = [] self._language = None + self._params = {} def query_string(self): """ @@ -209,6 +211,14 @@ def get_args(self): args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] + + if len(self._params) > 0: + args.append("PARAMS") + args.append(len(self._params)*2) + for key, value in self._params.items(): + args.append(key) + args.append(value) + return args def paging(self, offset, num): @@ -288,6 +298,25 @@ def sort_by(self, field, asc=True): self._sortby = SortbyField(field, asc) return self + def set_params_dict(self, params: Dict[str, Union[str, int, float]]): + """" + Add a parameters dictionary. Overwrites an existing parameters dictionary. + + - **params** - Dict[str, Union[str, int, float]] dictionary + """ + self._params = params + return self + + def add_param(self, param_name:str, value:Union[str, int, float]): + """ + Adds a parameter to the parameters dictionary. + + - **param_name** - parmaeter name + - **value** - parameter value + """ + self._params[param_name] = value + return self + class Filter(object): diff --git a/test/test.py b/test/test.py index 6253a43..bf066f0 100644 --- a/test/test.py +++ b/test/test.py @@ -1189,6 +1189,91 @@ def testSearchReturnFields(self): self.assertEqual('doc:1', total[0].id) self.assertEqual('telmatosaurus', total[0].txt) + def test_text_params(self): + conn = self.redis() + + with conn as r: + # Creating a client with a given index name + client = Client('idx', port=conn.port) + client.redis.flushdb() + client.create_index((TextField('name'),)) + + client.add_document('doc1', name='Alice') + client.add_document('doc2', name='Bob') + client.add_document('doc3', name='Carol') + + + q = Query("@name:($name1 | $name2 )").set_params_dict({"name1":"Alice", "name2":"Bob"}) + res = client.search(q) + self.assertEqual(2, res.total) + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + + q = Query("@name:($name1 | $name2 )").add_param("name1", "Alice").add_param("name2", "Bob") + res = client.search(q) + self.assertEqual(2, res.total) + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + + + def test_numeric_params(self): + conn = self.redis() + + with conn as r: + # Creating a client with a given index name + client = Client('idx', port=conn.port) + client.redis.flushdb() + client.create_index((NumericField('numval'),)) + + client.add_document('doc1', numval=101) + client.add_document('doc2', numval=102) + client.add_document('doc3', numval=103) + + q = Query('@numval:[$min $max]').set_params_dict({"min":101, "max":102}) + res = client.search(q) + self.assertEqual(2, res.total) + + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + + q = Query('@numval:[$min $max]').add_param("min", 101).add_param("max", 102) + res = client.search(q) + self.assertEqual(2, res.total) + + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + + def test_geo_params(self): + conn = self.redis() + + with conn as r: + # Creating a client with a given index name + client = Client('idx', port=conn.port) + client.redis.flushdb() + client.create_index((GeoField('g'),)) + + client.add_document('doc1', g='29.69465, 34.95126') + client.add_document('doc2', g='29.69350, 34.94737') + client.add_document('doc3', g='29.68746, 34.94882') + + q = Query('@g:[$lon $lat $radius $units]').set_params_dict({"lat":'34.95126', "lon":'29.69465', "radius":10, "units":"km"}) + res = client.search(q) + self.assertEqual(3, res.total) + + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + self.assertEqual('doc3', res.docs[2].id) + + + q = Query('@g:[$lon $lat $radius $units]').add_param("lat", '34.95126').add_param("lon", '29.69465',).add_param("radius", 10).add_param("units", "km") + res = client.search(q) + self.assertEqual(3, res.total) + + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + self.assertEqual('doc3', res.docs[2].id) + + if __name__ == '__main__': unittest.main() From 0dc35341230ed453036af7744a6dd94fcfd4a744 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Mon, 4 Oct 2021 10:02:47 +0300 Subject: [PATCH 2/5] changed add_param to set_param --- redisearch/query.py | 2 +- test/test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/redisearch/query.py b/redisearch/query.py index 75c1c94..b12afc3 100644 --- a/redisearch/query.py +++ b/redisearch/query.py @@ -307,7 +307,7 @@ def set_params_dict(self, params: Dict[str, Union[str, int, float]]): self._params = params return self - def add_param(self, param_name:str, value:Union[str, int, float]): + def set_param(self, param_name:str, value:Union[str, int, float]): """ Adds a parameter to the parameters dictionary. diff --git a/test/test.py b/test/test.py index bf066f0..5dcff97 100644 --- a/test/test.py +++ b/test/test.py @@ -1209,7 +1209,7 @@ def test_text_params(self): self.assertEqual('doc1', res.docs[0].id) self.assertEqual('doc2', res.docs[1].id) - q = Query("@name:($name1 | $name2 )").add_param("name1", "Alice").add_param("name2", "Bob") + q = Query("@name:($name1 | $name2 )").set_param("name1", "Alice").set_param("name2", "Bob") res = client.search(q) self.assertEqual(2, res.total) self.assertEqual('doc1', res.docs[0].id) @@ -1236,7 +1236,7 @@ def test_numeric_params(self): self.assertEqual('doc1', res.docs[0].id) self.assertEqual('doc2', res.docs[1].id) - q = Query('@numval:[$min $max]').add_param("min", 101).add_param("max", 102) + q = Query('@numval:[$min $max]').set_param("min", 101).set_param("max", 102) res = client.search(q) self.assertEqual(2, res.total) @@ -1265,7 +1265,7 @@ def test_geo_params(self): self.assertEqual('doc3', res.docs[2].id) - q = Query('@g:[$lon $lat $radius $units]').add_param("lat", '34.95126').add_param("lon", '29.69465',).add_param("radius", 10).add_param("units", "km") + q = Query('@g:[$lon $lat $radius $units]').set_param("lat", '34.95126').set_param("lon", '29.69465',).set_param("radius", 10).set_param("units", "km") res = client.search(q) self.assertEqual(3, res.total) From 14586ea7c0571e171e246383f0d1d5ab7c38bd0a Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Mon, 4 Oct 2021 12:21:27 +0300 Subject: [PATCH 3/5] search params --- redisearch/client.py | 14 ++++++++------ test/test.py | 32 ++++++++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/redisearch/client.py b/redisearch/client.py index 80971b4..858bba2 100644 --- a/redisearch/client.py +++ b/redisearch/client.py @@ -1,3 +1,4 @@ +from typing import Dict, Union from redis import Redis, ConnectionPool import itertools import time @@ -501,7 +502,7 @@ def info(self): it = six.moves.map(to_string, res) return dict(six.moves.zip(it, it)) - def _mk_query_args(self, query): + def _mk_query_args(self, query, query_params): args = [self.index_name] if isinstance(query, six.string_types): @@ -509,11 +510,12 @@ def _mk_query_args(self, query): query = Query(query) if not isinstance(query, Query): raise ValueError("Bad query type %s" % type(query)) - + if query_params is not None: + query.set_params_dict(query_params) args += query.get_args() return args, query - def search(self, query): + def search(self, query, query_params: Dict[str, Union[str, int, float]] = None): """ Search the index for a given query, and return a result of documents @@ -522,7 +524,7 @@ def search(self, query): - **query**: the search query. Either a text for simple queries with default parameters, or a Query object for complex queries. See RediSearch's documentation on query format """ - args, query = self._mk_query_args(query) + args, query = self._mk_query_args(query, query_params=query_params) st = time.time() res = self.redis.execute_command(self.SEARCH_CMD, *args) @@ -532,8 +534,8 @@ def search(self, query): has_payload=query._with_payloads, with_scores=query._with_scores) - def explain(self, query): - args, query_text = self._mk_query_args(query) + def explain(self, query, query_params: Dict[str, Union[str, int, float]] = None): + args, query_text = self._mk_query_args(query, query_params=query_params) return self.redis.execute_command(self.EXPLAIN_CMD, *args) def aggregate(self, query): diff --git a/test/test.py b/test/test.py index 5dcff97..f2bf84f 100644 --- a/test/test.py +++ b/test/test.py @@ -1202,8 +1202,8 @@ def test_text_params(self): client.add_document('doc2', name='Bob') client.add_document('doc3', name='Carol') - - q = Query("@name:($name1 | $name2 )").set_params_dict({"name1":"Alice", "name2":"Bob"}) + params_dict = {"name1":"Alice", "name2":"Bob"} + q = Query("@name:($name1 | $name2 )").set_params_dict(params=params_dict) res = client.search(q) self.assertEqual(2, res.total) self.assertEqual('doc1', res.docs[0].id) @@ -1215,6 +1215,12 @@ def test_text_params(self): self.assertEqual('doc1', res.docs[0].id) self.assertEqual('doc2', res.docs[1].id) + q = Query("@name:($name1 | $name2 )") + res = client.search(q, query_params=params_dict) + self.assertEqual(2, res.total) + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + def test_numeric_params(self): conn = self.redis() @@ -1229,7 +1235,8 @@ def test_numeric_params(self): client.add_document('doc2', numval=102) client.add_document('doc3', numval=103) - q = Query('@numval:[$min $max]').set_params_dict({"min":101, "max":102}) + params_dict = {"min":101, "max":102} + q = Query('@numval:[$min $max]').set_params_dict(params=params_dict) res = client.search(q) self.assertEqual(2, res.total) @@ -1243,6 +1250,13 @@ def test_numeric_params(self): self.assertEqual('doc1', res.docs[0].id) self.assertEqual('doc2', res.docs[1].id) + q = Query('@numval:[$min $max]') + res = client.search(q, query_params=params_dict) + self.assertEqual(2, res.total) + + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + def test_geo_params(self): conn = self.redis() @@ -1256,7 +1270,9 @@ def test_geo_params(self): client.add_document('doc2', g='29.69350, 34.94737') client.add_document('doc3', g='29.68746, 34.94882') - q = Query('@g:[$lon $lat $radius $units]').set_params_dict({"lat":'34.95126', "lon":'29.69465', "radius":10, "units":"km"}) + params_dict = {"lat":'34.95126', "lon":'29.69465', "radius":10, "units":"km"} + + q = Query('@g:[$lon $lat $radius $units]').set_params_dict(params=params_dict) res = client.search(q) self.assertEqual(3, res.total) @@ -1273,6 +1289,14 @@ def test_geo_params(self): self.assertEqual('doc2', res.docs[1].id) self.assertEqual('doc3', res.docs[2].id) + q = Query('@g:[$lon $lat $radius $units]') + res = client.search(q, query_params=params_dict) + self.assertEqual(3, res.total) + + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + self.assertEqual('doc3', res.docs[2].id) + if __name__ == '__main__': From f3bbad6fec054ac523782807ec9c0b15639801ec Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Mon, 4 Oct 2021 13:17:41 +0300 Subject: [PATCH 4/5] moved getting params argument as part to class method --- redisearch/client.py | 9 +++++---- redisearch/query.py | 18 ++++++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/redisearch/client.py b/redisearch/client.py index 858bba2..cbb58f3 100644 --- a/redisearch/client.py +++ b/redisearch/client.py @@ -510,9 +510,9 @@ def _mk_query_args(self, query, query_params): query = Query(query) if not isinstance(query, Query): raise ValueError("Bad query type %s" % type(query)) - if query_params is not None: - query.set_params_dict(query_params) args += query.get_args() + if query_params is not None: + args+= Query.get_params_args(query_params) return args, query def search(self, query, query_params: Dict[str, Union[str, int, float]] = None): @@ -538,7 +538,7 @@ def explain(self, query, query_params: Dict[str, Union[str, int, float]] = None) args, query_text = self._mk_query_args(query, query_params=query_params) return self.redis.execute_command(self.EXPLAIN_CMD, *args) - def aggregate(self, query): + def aggregate(self, query, query_params: Dict[str, Union[str, int, float]] = None): """ Issue an aggregation query @@ -558,7 +558,8 @@ def aggregate(self, query): self.index_name] + query.build_args() else: raise ValueError('Bad query', query) - + if query_params is not None: + cmd+= Query.get_params_args(query_params) raw = self.redis.execute_command(*cmd) if has_cursor: if isinstance(query, Cursor): diff --git a/redisearch/query.py b/redisearch/query.py index b12afc3..1b2d8e6 100644 --- a/redisearch/query.py +++ b/redisearch/query.py @@ -149,6 +149,17 @@ def scorer(self, scorer): self._scorer = scorer return self + @staticmethod + def get_params_args(params: Dict[str, Union[str, int, float]]): + args = [] + if len(params) > 0: + args.append("PARAMS") + args.append(len(params)*2) + for key, value in params.items(): + args.append(key) + args.append(value) + return args + def get_args(self): """ Format the redis arguments for this query and return them @@ -212,12 +223,7 @@ def get_args(self): args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] - if len(self._params) > 0: - args.append("PARAMS") - args.append(len(self._params)*2) - for key, value in self._params.items(): - args.append(key) - args.append(value) + args+= Query.get_params_args(self._params) return args From a2b2796fffd435ef3b93f09f2e5b0349012e6b53 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Mon, 4 Oct 2021 15:12:04 +0300 Subject: [PATCH 5/5] removed parameter from query builder pattern --- redisearch/client.py | 14 ++++++++++++-- redisearch/query.py | 35 --------------------------------- test/test.py | 46 -------------------------------------------- 3 files changed, 12 insertions(+), 83 deletions(-) diff --git a/redisearch/client.py b/redisearch/client.py index cbb58f3..fda9ecf 100644 --- a/redisearch/client.py +++ b/redisearch/client.py @@ -502,6 +502,16 @@ def info(self): it = six.moves.map(to_string, res) return dict(six.moves.zip(it, it)) + def get_params_args(self, params: Dict[str, Union[str, int, float]]): + args = [] + if len(params) > 0: + args.append("PARAMS") + args.append(len(params)*2) + for key, value in params.items(): + args.append(key) + args.append(value) + return args + def _mk_query_args(self, query, query_params): args = [self.index_name] @@ -512,7 +522,7 @@ def _mk_query_args(self, query, query_params): raise ValueError("Bad query type %s" % type(query)) args += query.get_args() if query_params is not None: - args+= Query.get_params_args(query_params) + args+= self.get_params_args(query_params) return args, query def search(self, query, query_params: Dict[str, Union[str, int, float]] = None): @@ -559,7 +569,7 @@ def aggregate(self, query, query_params: Dict[str, Union[str, int, float]] = Non else: raise ValueError('Bad query', query) if query_params is not None: - cmd+= Query.get_params_args(query_params) + cmd+= self.get_params_args(query_params) raw = self.redis.execute_command(*cmd) if has_cursor: if isinstance(query, Cursor): diff --git a/redisearch/query.py b/redisearch/query.py index 1b2d8e6..e03d2a4 100644 --- a/redisearch/query.py +++ b/redisearch/query.py @@ -1,4 +1,3 @@ -from typing import Dict, Union import six class Query(object): @@ -35,7 +34,6 @@ def __init__(self, query_string): self._summarize_fields = [] self._highlight_fields = [] self._language = None - self._params = {} def query_string(self): """ @@ -149,17 +147,6 @@ def scorer(self, scorer): self._scorer = scorer return self - @staticmethod - def get_params_args(params: Dict[str, Union[str, int, float]]): - args = [] - if len(params) > 0: - args.append("PARAMS") - args.append(len(params)*2) - for key, value in params.items(): - args.append(key) - args.append(value) - return args - def get_args(self): """ Format the redis arguments for this query and return them @@ -223,8 +210,6 @@ def get_args(self): args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] - args+= Query.get_params_args(self._params) - return args def paging(self, offset, num): @@ -304,26 +289,6 @@ def sort_by(self, field, asc=True): self._sortby = SortbyField(field, asc) return self - def set_params_dict(self, params: Dict[str, Union[str, int, float]]): - """" - Add a parameters dictionary. Overwrites an existing parameters dictionary. - - - **params** - Dict[str, Union[str, int, float]] dictionary - """ - self._params = params - return self - - def set_param(self, param_name:str, value:Union[str, int, float]): - """ - Adds a parameter to the parameters dictionary. - - - **param_name** - parmaeter name - - **value** - parameter value - """ - self._params[param_name] = value - return self - - class Filter(object): def __init__(self, keyword, field, *args): diff --git a/test/test.py b/test/test.py index f2bf84f..f2b9f4c 100644 --- a/test/test.py +++ b/test/test.py @@ -1203,18 +1203,6 @@ def test_text_params(self): client.add_document('doc3', name='Carol') params_dict = {"name1":"Alice", "name2":"Bob"} - q = Query("@name:($name1 | $name2 )").set_params_dict(params=params_dict) - res = client.search(q) - self.assertEqual(2, res.total) - self.assertEqual('doc1', res.docs[0].id) - self.assertEqual('doc2', res.docs[1].id) - - q = Query("@name:($name1 | $name2 )").set_param("name1", "Alice").set_param("name2", "Bob") - res = client.search(q) - self.assertEqual(2, res.total) - self.assertEqual('doc1', res.docs[0].id) - self.assertEqual('doc2', res.docs[1].id) - q = Query("@name:($name1 | $name2 )") res = client.search(q, query_params=params_dict) self.assertEqual(2, res.total) @@ -1236,24 +1224,9 @@ def test_numeric_params(self): client.add_document('doc3', numval=103) params_dict = {"min":101, "max":102} - q = Query('@numval:[$min $max]').set_params_dict(params=params_dict) - res = client.search(q) - self.assertEqual(2, res.total) - - self.assertEqual('doc1', res.docs[0].id) - self.assertEqual('doc2', res.docs[1].id) - - q = Query('@numval:[$min $max]').set_param("min", 101).set_param("max", 102) - res = client.search(q) - self.assertEqual(2, res.total) - - self.assertEqual('doc1', res.docs[0].id) - self.assertEqual('doc2', res.docs[1].id) - q = Query('@numval:[$min $max]') res = client.search(q, query_params=params_dict) self.assertEqual(2, res.total) - self.assertEqual('doc1', res.docs[0].id) self.assertEqual('doc2', res.docs[1].id) @@ -1271,28 +1244,9 @@ def test_geo_params(self): client.add_document('doc3', g='29.68746, 34.94882') params_dict = {"lat":'34.95126', "lon":'29.69465', "radius":10, "units":"km"} - - q = Query('@g:[$lon $lat $radius $units]').set_params_dict(params=params_dict) - res = client.search(q) - self.assertEqual(3, res.total) - - self.assertEqual('doc1', res.docs[0].id) - self.assertEqual('doc2', res.docs[1].id) - self.assertEqual('doc3', res.docs[2].id) - - - q = Query('@g:[$lon $lat $radius $units]').set_param("lat", '34.95126').set_param("lon", '29.69465',).set_param("radius", 10).set_param("units", "km") - res = client.search(q) - self.assertEqual(3, res.total) - - self.assertEqual('doc1', res.docs[0].id) - self.assertEqual('doc2', res.docs[1].id) - self.assertEqual('doc3', res.docs[2].id) - q = Query('@g:[$lon $lat $radius $units]') res = client.search(q, query_params=params_dict) self.assertEqual(3, res.total) - self.assertEqual('doc1', res.docs[0].id) self.assertEqual('doc2', res.docs[1].id) self.assertEqual('doc3', res.docs[2].id)