diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index 127141f650..9e789c03a9 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -59,45 +59,6 @@ def parse_to_list(response): return res -def parse_list_to_dict(response): - res = {} - for i in range(0, len(response), 2): - if isinstance(response[i], list): - res["Child iterators"].append(parse_list_to_dict(response[i])) - try: - if isinstance(response[i + 1], list): - res["Child iterators"].append(parse_list_to_dict(response[i + 1])) - except IndexError: - pass - elif isinstance(response[i + 1], list): - res["Child iterators"] = [parse_list_to_dict(response[i + 1])] - else: - try: - res[response[i]] = float(response[i + 1]) - except (TypeError, ValueError): - res[response[i]] = response[i + 1] - return res - - -def parse_to_dict(response): - if response is None: - return {} - - res = {} - for det in response: - if isinstance(det[1], list): - res[det[0]] = parse_list_to_dict(det[1]) - else: - try: # try to set the attribute. may be provided without value - try: # try to convert the value to float - res[det[0]] = float(det[1]) - except (TypeError, ValueError): - res[det[0]] = det[1] - except IndexError: - pass - return res - - def random_string(length=10): """ Returns a random N character long string. diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 2df2b5a754..3de537f189 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,11 +1,12 @@ import itertools import time -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from redis.client import Pipeline from redis.utils import deprecated_function -from ..helpers import get_protocol_version, parse_to_dict +from ..._parsers.helpers import pairs_to_dict +from ..helpers import get_protocol_version from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor from .document import Document @@ -85,7 +86,9 @@ def _parse_search(self, res, **kwargs): def _parse_aggregate(self, res, **kwargs): return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) - def _parse_profile(self, res, **kwargs): + def _parse_profile( + self, res, **kwargs + ) -> Tuple[Union[AggregateResult, Result], Dict[str, Any]]: query = kwargs["query"] if isinstance(query, AggregateRequest): result = self._get_aggregate_result(res[0], query, query._cursor) @@ -98,7 +101,22 @@ def _parse_profile(self, res, **kwargs): with_scores=query._with_scores, ) - return result, parse_to_dict(res[1]) + details = pairs_to_dict(res[1]) + details["Coordinator"] = pairs_to_dict(details["Coordinator"]) + details["Shards"] = [pairs_to_dict(s) for s in details["Shards"]] + for shard in details["Shards"]: + if "Iterators profile" in shard: + shard["Iterators profile"] = pairs_to_dict(shard["Iterators profile"]) + if "Child iterators" in shard["Iterators profile"]: + shard["Iterators profile"]["Child iterators"] = [ + pairs_to_dict(c) + for c in shard["Iterators profile"]["Child iterators"] + ] + if "Result processors profile" in shard: + shard["Result processors profile"] = [ + pairs_to_dict(r) for r in shard["Result processors profile"] + ] + return result, details def _parse_spellcheck(self, res, **kwargs): corrections = {} diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 66ee1c5390..06265d382e 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -4,7 +4,6 @@ delist, list_or_args, nativestr, - parse_to_dict, parse_to_list, quote_string, random_string, @@ -26,40 +25,6 @@ def test_parse_to_list(): assert parse_to_list(r) == ["hello", "my name", 45, 555.55, "is simon!", None] -def test_parse_to_dict(): - assert parse_to_dict(None) == {} - r = [ - ["Some number", "1.0345"], - ["Some string", "hello"], - [ - "Child iterators", - [ - "Time", - "0.2089", - "Counter", - 3, - "Child iterators", - ["Type", "bar", "Time", "0.0729", "Counter", 3], - ["Type", "barbar", "Time", "0.058", "Counter", 3], - ["Type", "barbarbar", "Time", "0.0234", "Counter", 3], - ], - ], - ] - assert parse_to_dict(r) == { - "Child iterators": { - "Child iterators": [ - {"Counter": 3.0, "Time": 0.0729, "Type": "bar"}, - {"Counter": 3.0, "Time": 0.058, "Type": "barbar"}, - {"Counter": 3.0, "Time": 0.0234, "Type": "barbarbar"}, - ], - "Counter": 3.0, - "Time": 0.2089, - }, - "Some number": 1.0345, - "Some string": "hello", - } - - def test_nativestr(): assert nativestr("teststr") == "teststr" assert nativestr(b"teststr") == "teststr" diff --git a/tests/test_search.py b/tests/test_search.py index bfe204254c..c6ad9eb958 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -108,12 +108,21 @@ def createIndex(client, num_docs=100, definition=None): @pytest.fixture def client(request): - r = _get_client(redis.Redis, request, decode_responses=True) + if hasattr(request, "param"): + protocol = request.param.get("protocol", None) + else: + protocol = None + r = _get_client(redis.Redis, request, decode_responses=True, protocol=protocol) r.flushdb() return r @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_client(client): num_docs = 500 createIndex(client.ft(), num_docs=num_docs) @@ -311,6 +320,11 @@ def test_client(client): @pytest.mark.redismod @pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_scores(client): client.ft().create_index((TextField("txt"),)) @@ -332,6 +346,11 @@ def test_scores(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_stopwords(client): client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) client.hset("doc1", mapping={"txt": "foo bar"}) @@ -350,6 +369,11 @@ def test_stopwords(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_filters(client): client.ft().create_index((TextField("txt"), NumericField("num"), GeoField("loc"))) client.hset( @@ -403,6 +427,11 @@ def test_filters(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_sort_by(client): client.ft().create_index((TextField("txt"), NumericField("num", sortable=True))) client.hset("doc1", mapping={"txt": "foo bar", "num": 1}) @@ -525,6 +554,11 @@ def test_auto_complete(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_no_index(client): client.ft().create_index( ( @@ -616,6 +650,11 @@ def test_explaincli(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_summarize(client): createIndex(client.ft()) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) @@ -660,6 +699,11 @@ def test_summarize(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_alias(client): index1 = getClient(client) index2 = getClient(client) @@ -723,6 +767,11 @@ def test_alias(client): @pytest.mark.redismod @pytest.mark.xfail(strict=False) +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_alias_basic(client): # Creating a client with one index index1 = getClient(client).ft("testAlias") @@ -771,6 +820,11 @@ def test_alias_basic(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_textfield_sortable_nostem(client): # Creating the index definition with sortable and no_stem client.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) @@ -786,6 +840,11 @@ def test_textfield_sortable_nostem(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_alter_schema_add(client): # Creating the index definition and schema client.ft().create_index(TextField("title")) @@ -810,6 +869,11 @@ def test_alter_schema_add(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_spell_check(client): client.ft().create_index((TextField("f1"), TextField("f2"))) @@ -879,6 +943,11 @@ def test_spell_check(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_dict_operations(client): client.ft().create_index((TextField("f1"), TextField("f2"))) # Add three items @@ -898,6 +967,11 @@ def test_dict_operations(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_phonetic_matcher(client): client.ft().create_index((TextField("name"),)) client.hset("doc1", mapping={"name": "Jon"}) @@ -931,6 +1005,11 @@ def test_phonetic_matcher(client): @pytest.mark.redismod @pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_scorer(client): client.ft().create_index((TextField("description"),)) @@ -1015,6 +1094,11 @@ def test_config(client): @pytest.mark.redismod @pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_aggregations_groupby(client): # Creating the index definition and schema client.ft().create_index( @@ -1264,6 +1348,11 @@ def test_aggregations_groupby(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_aggregations_sort_by_and_limit(client): client.ft().create_index((TextField("t1"), TextField("t2"))) @@ -1323,6 +1412,11 @@ def test_aggregations_sort_by_and_limit(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_aggregations_load(client): client.ft().create_index((TextField("t1"), TextField("t2"))) @@ -1361,6 +1455,11 @@ def test_aggregations_load(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_aggregations_apply(client): client.ft().create_index( ( @@ -1396,6 +1495,11 @@ def test_aggregations_apply(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_aggregations_filter(client): client.ft().create_index( (TextField("name", sortable=True), NumericField("age", sortable=True)) @@ -1499,6 +1603,11 @@ def test_expire(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_skip_initial_scan(client): client.hset("doc1", "foo", "bar") q = Query("@foo:bar") @@ -1584,6 +1693,11 @@ def test_create_client_definition_hash(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.2.0", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_create_client_definition_json(client): """ Create definition with IndexType.JSON as index type (ON JSON), @@ -1609,6 +1723,11 @@ def test_create_client_definition_json(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.2.0", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_fields_as_name(client): # create index SCHEMA = ( @@ -1636,6 +1755,11 @@ def test_fields_as_name(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_casesensitive(client): # create index SCHEMA = (TagField("t", case_sensitive=False),) @@ -1671,6 +1795,11 @@ def test_casesensitive(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.2.0", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_search_return_fields(client): res = client.json().set( "doc:1", @@ -1708,6 +1837,11 @@ def test_search_return_fields(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_synupdate(client): definition = IndexDefinition(index_type=IndexType.HASH) client.ft().create_index( @@ -1754,6 +1888,11 @@ def test_syndump(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.2.0", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_create_json_with_alias(client): """ Create definition with IndexType.JSON as index type (ON JSON) with two @@ -1799,6 +1938,11 @@ def test_create_json_with_alias(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.2.0", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_json_with_multipath(client): """ Create definition with IndexType.JSON as index type (ON JSON), @@ -1843,6 +1987,11 @@ def test_json_with_multipath(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.2.0", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_json_with_jsonpath(client): definition = IndexDefinition(index_type=IndexType.JSON) client.ft().create_index( @@ -1894,6 +2043,11 @@ def test_json_with_jsonpath(client): @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_if_redis_enterprise() +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_profile(client): client.ft().create_index((TextField("t"),)) client.ft().client.hset("1", "t", "hello") @@ -1903,10 +2057,11 @@ def test_profile(client): q = Query("hello|world").no_content() if is_resp2_connection(client): res, det = client.ft().profile(q) - assert det["Iterators profile"]["Counter"] == 2.0 - assert len(det["Iterators profile"]["Child iterators"]) == 2 - assert det["Iterators profile"]["Type"] == "UNION" - assert det["Parsing time"] < 0.5 + iterators_profile = det["Shards"][0]["Iterators profile"] + assert iterators_profile["Counter"] == 2 + assert len(iterators_profile["Child iterators"]) == 2 + assert iterators_profile["Type"] == "UNION" + assert float(det["Shards"][0]["Parsing time"]) < 0.5 assert len(res.docs) == 2 # check also the search result # check using AggregateRequest @@ -1916,16 +2071,17 @@ def test_profile(client): .apply(prefix="startswith(@t, 'hel')") ) res, det = client.ft().profile(req) - assert det["Iterators profile"]["Counter"] == 2 - assert det["Iterators profile"]["Type"] == "WILDCARD" - assert isinstance(det["Parsing time"], float) + iterators_profile = det["Shards"][0]["Iterators profile"] + assert iterators_profile["Counter"] == 2 + assert iterators_profile["Type"] == "WILDCARD" assert len(res.rows) == 2 # check also the search result else: res = client.ft().profile(q) - assert res["profile"]["Iterators profile"][0]["Counter"] == 2.0 - assert res["profile"]["Iterators profile"][0]["Type"] == "UNION" - assert res["profile"]["Parsing time"] < 0.5 - assert len(res["results"]) == 2 # check also the search result + iterators_profile = res["Profile"]["Shards"][0]["Iterators profile"] + assert iterators_profile["Counter"] == 2 + assert iterators_profile["Type"] == "UNION" + assert res["Profile"]["Shards"][0]["Parsing time"] < 0.5 + assert len(res["Results"]["results"]) == 2 # check also the search result # check using AggregateRequest req = ( @@ -1934,14 +2090,20 @@ def test_profile(client): .apply(prefix="startswith(@t, 'hel')") ) res = client.ft().profile(req) - assert res["profile"]["Iterators profile"][0]["Counter"] == 2 - assert res["profile"]["Iterators profile"][0]["Type"] == "WILDCARD" - assert isinstance(res["profile"]["Parsing time"], float) - assert len(res["results"]) == 2 # check also the search result + iterators_profile = res["Profile"]["Shards"][0]["Iterators profile"] + assert iterators_profile["Counter"] == 2 + assert iterators_profile["Type"] == "WILDCARD" + assert isinstance(res["Profile"]["Shards"][0]["Parsing time"], float) + assert len(res["Results"]["results"]) == 2 # check also the search result @pytest.mark.redismod @pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_profile_limited(client): client.ft().create_index((TextField("t"),)) client.ft().client.hset("1", "t", "hello") @@ -1952,33 +2114,39 @@ def test_profile_limited(client): q = Query("%hell% hel*") if is_resp2_connection(client): res, det = client.ft().profile(q, limited=True) + iterators_profile = det["Shards"][0]["Iterators profile"] assert ( - det["Iterators profile"]["Child iterators"][0]["Child iterators"] + iterators_profile["Child iterators"][0]["Child iterators"] == "The number of iterators in the union is 3" ) assert ( - det["Iterators profile"]["Child iterators"][1]["Child iterators"] + iterators_profile["Child iterators"][1]["Child iterators"] == "The number of iterators in the union is 4" ) - assert det["Iterators profile"]["Type"] == "INTERSECT" + assert iterators_profile["Type"] == "INTERSECT" assert len(res.docs) == 3 # check also the search result else: res = client.ft().profile(q, limited=True) - iterators_profile = res["profile"]["Iterators profile"] + iterators_profile = res["Profile"]["Shards"][0]["Iterators profile"] assert ( - iterators_profile[0]["Child iterators"][0]["Child iterators"] + iterators_profile["Child iterators"][0]["Child iterators"] == "The number of iterators in the union is 3" ) assert ( - iterators_profile[0]["Child iterators"][1]["Child iterators"] + iterators_profile["Child iterators"][1]["Child iterators"] == "The number of iterators in the union is 4" ) - assert iterators_profile[0]["Type"] == "INTERSECT" - assert len(res["results"]) == 3 # check also the search result + assert iterators_profile["Type"] == "INTERSECT" + assert len(res["Results"]["results"]) == 3 # check also the search result @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_profile_query_params(client): client.ft().create_index( ( @@ -1994,22 +2162,29 @@ def test_profile_query_params(client): q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) if is_resp2_connection(client): res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "VECTOR" + iterators_profile = det["Shards"][0]["Iterators profile"] + assert iterators_profile["Counter"] == 2.0 + assert iterators_profile["Type"] == "VECTOR" assert res.total == 2 assert "a" == res.docs[0].id assert "0" == res.docs[0].__getattribute__("__v_score") else: res = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - assert res["profile"]["Iterators profile"][0]["Counter"] == 2 - assert res["profile"]["Iterators profile"][0]["Type"] == "VECTOR" - assert res["total_results"] == 2 - assert "a" == res["results"][0]["id"] - assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + iterators_profile = res["Profile"]["Shards"][0]["Iterators profile"] + assert iterators_profile["Counter"] == 2 + assert iterators_profile["Type"] == "VECTOR" + assert res["Results"]["total_results"] == 2 + assert "a" == res["Results"]["results"][0]["id"] + assert "0" == res["Results"]["results"][0]["extra_attributes"]["__v_score"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_vector_field(client): client.flushdb() client.ft().create_index( @@ -2051,6 +2226,11 @@ def test_vector_field_error(r): @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_text_params(client): client.flushdb() client.ft().create_index((TextField("name"),)) @@ -2074,6 +2254,11 @@ def test_text_params(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_numeric_params(client): client.flushdb() client.ft().create_index((NumericField("numval"),)) @@ -2098,6 +2283,11 @@ def test_numeric_params(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_geo_params(client): client.ft().create_index((GeoField("g"))) client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) @@ -2121,6 +2311,11 @@ def test_geo_params(client): @pytest.mark.redismod @skip_if_redis_enterprise() +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_search_commands_in_pipeline(client): p = client.ft().pipeline() p.create_index((TextField("txt"),)) @@ -2195,6 +2390,11 @@ def test_dialect(client): @pytest.mark.redismod +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_expire_while_search(client: redis.Redis): client.ft().create_index((TextField("txt"),)) client.hset("hset:1", "txt", "a") @@ -2218,6 +2418,11 @@ def test_expire_while_search(client: redis.Redis): @pytest.mark.redismod @pytest.mark.experimental +@pytest.mark.parametrize( + "client", + [{"protocol": 2}, {"protocol": 3}], + indirect=True, +) def test_withsuffixtrie(client: redis.Redis): # create index assert client.ft().create_index((TextField("txt"),))