Skip to content

Commit 889534d

Browse files
committed
Refactor message parsing for more actions
1 parent 9b038aa commit 889534d

File tree

2 files changed

+187
-63
lines changed

2 files changed

+187
-63
lines changed

pinecone/grpc/utils.py

Lines changed: 119 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any
1+
from __future__ import annotations
2+
3+
from typing import Any, TYPE_CHECKING
24
from google.protobuf import json_format
35
from google.protobuf.message import Message
46

@@ -28,6 +30,17 @@
2830

2931
from google.protobuf.struct_pb2 import Struct
3032

33+
if TYPE_CHECKING:
34+
from pinecone.core.grpc.protos.db_data_2025_10_pb2 import (
35+
FetchResponse as ProtoFetchResponse,
36+
FetchByMetadataResponse as ProtoFetchByMetadataResponse,
37+
QueryResponse as ProtoQueryResponse,
38+
UpsertResponse as ProtoUpsertResponse,
39+
UpdateResponse as ProtoUpdateResponse,
40+
NamespaceDescription as ProtoNamespaceDescription,
41+
ListNamespacesResponse as ProtoListNamespacesResponse,
42+
)
43+
3144

3245
def _generate_request_id() -> str:
3346
return str(uuid.uuid4())
@@ -53,7 +66,7 @@ def parse_sparse_values(sparse_values: dict | None) -> SparseValues:
5366

5467

5568
def parse_fetch_response(
56-
response: Message, initial_metadata: dict[str, str] | None = None
69+
response: "ProtoFetchResponse", initial_metadata: dict[str, str] | None = None
5770
) -> FetchResponse:
5871
"""Parse a FetchResponse protobuf message directly without MessageToDict conversion.
5972
@@ -105,36 +118,61 @@ def parse_fetch_response(
105118

106119

107120
def parse_fetch_by_metadata_response(
108-
response: Message, initial_metadata: dict[str, str] | None = None
121+
response: "ProtoFetchByMetadataResponse", initial_metadata: dict[str, str] | None = None
109122
) -> FetchByMetadataResponse:
110-
json_response = json_format.MessageToDict(response)
111-
112-
vd = {}
113-
vectors = json_response.get("vectors", {})
114-
namespace = json_response.get("namespace", "")
115-
116-
for id, vec in vectors.items():
117-
vd[id] = _Vector(
118-
id=vec["id"],
119-
values=vec.get("values", None),
120-
sparse_values=parse_sparse_values(vec.get("sparseValues", None)),
121-
metadata=vec.get("metadata", None),
122-
_check_type=False,
123-
)
124-
125-
pagination = None
126-
if json_response.get("pagination") and json_response["pagination"].get("next"):
127-
pagination = Pagination(next=json_response["pagination"]["next"])
123+
"""Parse a FetchByMetadataResponse protobuf message directly without MessageToDict conversion.
128124
125+
This optimized version directly accesses protobuf fields for better performance.
126+
"""
129127
# Extract response info from initial metadata
130128
from pinecone.utils.response_info import extract_response_info
131129

132130
metadata = initial_metadata or {}
133131
response_info = extract_response_info(metadata)
134132

133+
# Directly access protobuf fields instead of converting entire message to dict
134+
vd = {}
135+
# namespace is a required string field, so it will always have a value (default empty string)
136+
namespace = response.namespace
137+
138+
# Iterate over vectors map directly
139+
for vec_id, vec in response.vectors.items():
140+
# Convert vector.values (RepeatedScalarFieldContainer) to list
141+
values = list(vec.values) if vec.values else None
142+
143+
# Handle sparse_values if present
144+
parsed_sparse = None
145+
if vec.HasField("sparse_values") and vec.sparse_values:
146+
parsed_sparse = parse_sparse_values(
147+
{
148+
"indices": list(vec.sparse_values.indices),
149+
"values": list(vec.sparse_values.values),
150+
}
151+
)
152+
153+
# Convert metadata Struct to dict only when needed
154+
metadata_dict = None
155+
if vec.HasField("metadata") and vec.metadata:
156+
metadata_dict = json_format.MessageToDict(vec.metadata)
157+
158+
vd[vec_id] = _Vector(
159+
id=vec.id,
160+
values=values,
161+
sparse_values=parsed_sparse,
162+
metadata=metadata_dict,
163+
_check_type=False,
164+
)
165+
166+
# Parse pagination if present
167+
pagination = None
168+
if response.HasField("pagination") and response.pagination:
169+
pagination = Pagination(next=response.pagination.next)
170+
171+
# Parse usage if present
135172
usage = None
136-
if json_response.get("usage"):
137-
usage = parse_usage(json_response.get("usage", {}))
173+
if response.HasField("usage") and response.usage:
174+
usage = parse_usage({"readUnits": response.usage.read_units})
175+
138176
fetch_by_metadata_response = FetchByMetadataResponse(
139177
vectors=vd,
140178
namespace=namespace,
@@ -153,43 +191,50 @@ def parse_usage(usage: dict) -> Usage:
153191

154192

155193
def parse_upsert_response(
156-
response: Message, _check_type: bool = False, initial_metadata: dict[str, str] | None = None
194+
response: "ProtoUpsertResponse",
195+
_check_type: bool = False,
196+
initial_metadata: dict[str, str] | None = None,
157197
) -> UpsertResponse:
158-
from pinecone.utils.response_info import extract_response_info
198+
"""Parse an UpsertResponse protobuf message directly without MessageToDict conversion.
159199
160-
json_response = json_format.MessageToDict(response)
161-
upserted_count = json_response.get("upsertedCount", 0)
200+
This optimized version directly accesses protobuf fields for better performance.
201+
"""
202+
from pinecone.utils.response_info import extract_response_info
162203

163204
# Extract response info from initial metadata
164205
# For gRPC, LSN headers are in initial_metadata
165206
metadata = initial_metadata or {}
166207
response_info = extract_response_info(metadata)
167208

209+
# Directly access upserted_count field (required field in proto3, always has a value)
210+
upserted_count = response.upserted_count
211+
168212
return UpsertResponse(upserted_count=int(upserted_count), _response_info=response_info)
169213

170214

171215
def parse_update_response(
172-
response: dict | Message,
216+
response: dict | "ProtoUpdateResponse",
173217
_check_type: bool = False,
174218
initial_metadata: dict[str, str] | None = None,
175219
) -> UpdateResponse:
220+
"""Parse an UpdateResponse protobuf message directly without MessageToDict conversion.
221+
222+
This optimized version directly accesses protobuf fields for better performance.
223+
For dict responses (REST API), falls back to the original dict-based parsing.
224+
"""
176225
from pinecone.utils.response_info import extract_response_info
177-
from google.protobuf import json_format
178226

179227
# Extract response info from initial metadata
180228
metadata = initial_metadata or {}
181229
response_info = extract_response_info(metadata)
182230

183231
# Extract matched_records from response
184232
matched_records = None
185-
if isinstance(response, Message):
186-
# GRPC response - convert to dict to extract matched_records
187-
json_response = json_format.MessageToDict(response)
188-
matched_records = json_response.get("matchedRecords") or json_response.get(
189-
"matched_records"
190-
)
233+
if isinstance(response, Message) and not isinstance(response, dict):
234+
# Optimized path: directly access protobuf field
235+
matched_records = response.matched_records if response.HasField("matched_records") else None
191236
elif isinstance(response, dict):
192-
# Dict response - extract directly
237+
# Fallback for dict responses (REST API)
193238
matched_records = response.get("matchedRecords") or response.get("matched_records")
194239

195240
return UpdateResponse(matched_records=matched_records, _response_info=response_info)
@@ -211,7 +256,7 @@ def parse_delete_response(
211256

212257

213258
def parse_query_response(
214-
response: dict | Message,
259+
response: dict | "ProtoQueryResponse",
215260
_check_type: bool = False,
216261
initial_metadata: dict[str, str] | None = None,
217262
) -> QueryResponse:
@@ -226,7 +271,7 @@ def parse_query_response(
226271
metadata = initial_metadata or {}
227272
response_info = extract_response_info(metadata)
228273

229-
if isinstance(response, Message):
274+
if isinstance(response, Message) and not isinstance(response, dict):
230275
# Optimized path: directly access protobuf fields
231276
matches = []
232277
# namespace is a required string field, so it will always have a value (default empty string)
@@ -320,26 +365,30 @@ def parse_stats_response(response: dict) -> "DescribeIndexStatsResponse":
320365

321366

322367
def parse_namespace_description(
323-
response: Message, initial_metadata: dict[str, str] | None = None
368+
response: "ProtoNamespaceDescription", initial_metadata: dict[str, str] | None = None
324369
) -> NamespaceDescription:
370+
"""Parse a NamespaceDescription protobuf message directly without MessageToDict conversion.
371+
372+
This optimized version directly accesses protobuf fields for better performance.
373+
"""
325374
from pinecone.utils.response_info import extract_response_info
326375

327-
json_response = json_format.MessageToDict(response)
376+
# Directly access protobuf fields
377+
name = response.name
378+
record_count = response.record_count
328379

329380
# Extract indexed_fields if present
330381
indexed_fields = None
331-
if "indexedFields" in json_response and json_response["indexedFields"]:
332-
indexed_fields_data = json_response["indexedFields"]
333-
if "fields" in indexed_fields_data:
382+
if response.HasField("indexed_fields") and response.indexed_fields:
383+
# Access indexed_fields.fields directly (RepeatedScalarFieldContainer)
384+
fields_list = list(response.indexed_fields.fields) if response.indexed_fields.fields else []
385+
if fields_list:
334386
indexed_fields = NamespaceDescriptionIndexedFields(
335-
fields=indexed_fields_data.get("fields", []), _check_type=False
387+
fields=fields_list, _check_type=False
336388
)
337389

338390
namespace_desc = NamespaceDescription(
339-
name=json_response.get("name", ""),
340-
record_count=json_response.get("recordCount", 0),
341-
indexed_fields=indexed_fields,
342-
_check_type=False,
391+
name=name, record_count=record_count, indexed_fields=indexed_fields, _check_type=False
343392
)
344393

345394
# Attach _response_info as an attribute (NamespaceDescription is an OpenAPI model)
@@ -352,36 +401,44 @@ def parse_namespace_description(
352401
return cast(NamespaceDescription, namespace_desc)
353402

354403

355-
def parse_list_namespaces_response(response: Message) -> ListNamespacesResponse:
356-
json_response = json_format.MessageToDict(response)
404+
def parse_list_namespaces_response(
405+
response: "ProtoListNamespacesResponse",
406+
) -> ListNamespacesResponse:
407+
"""Parse a ListNamespacesResponse protobuf message directly without MessageToDict conversion.
357408
409+
This optimized version directly accesses protobuf fields for better performance.
410+
"""
411+
# Directly iterate over namespaces
358412
namespaces = []
359-
for ns in json_response.get("namespaces", []):
413+
for ns in response.namespaces:
360414
# Extract indexed_fields if present
361415
indexed_fields = None
362-
if "indexedFields" in ns and ns["indexedFields"]:
363-
indexed_fields_data = ns["indexedFields"]
364-
if "fields" in indexed_fields_data:
416+
if ns.HasField("indexed_fields") and ns.indexed_fields:
417+
# Access indexed_fields.fields directly (RepeatedScalarFieldContainer)
418+
fields_list = list(ns.indexed_fields.fields) if ns.indexed_fields.fields else []
419+
if fields_list:
365420
indexed_fields = NamespaceDescriptionIndexedFields(
366-
fields=indexed_fields_data.get("fields", []), _check_type=False
421+
fields=fields_list, _check_type=False
367422
)
368423

369424
namespaces.append(
370425
NamespaceDescription(
371-
name=ns.get("name", ""),
372-
record_count=ns.get("recordCount", 0),
426+
name=ns.name,
427+
record_count=ns.record_count,
373428
indexed_fields=indexed_fields,
374429
_check_type=False,
375430
)
376431
)
377432

433+
# Parse pagination if present
378434
pagination = None
379-
if "pagination" in json_response and json_response["pagination"]:
380-
pagination = OpenApiPagination(
381-
next=json_response["pagination"].get("next", ""), _check_type=False
382-
)
435+
if response.HasField("pagination") and response.pagination:
436+
pagination = OpenApiPagination(next=response.pagination.next, _check_type=False)
437+
438+
# Parse total_count (int field in proto3, always has a value, default 0)
439+
# If 0, treat as None to match original behavior
440+
total_count = response.total_count if response.total_count else None
383441

384-
total_count = json_response.get("totalCount")
385442
from typing import cast
386443

387444
result = ListNamespacesResponse(

tests/perf/test_grpc_parsing_perf.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,22 @@
1111
from pinecone.core.grpc.protos.db_data_2025_10_pb2 import (
1212
FetchResponse,
1313
QueryResponse,
14+
FetchByMetadataResponse,
15+
UpsertResponse,
16+
UpdateResponse,
1417
Vector,
1518
ScoredVector,
1619
SparseValues,
1720
Usage,
21+
Pagination,
22+
)
23+
from pinecone.grpc.utils import (
24+
parse_fetch_response,
25+
parse_query_response,
26+
parse_fetch_by_metadata_response,
27+
parse_upsert_response,
28+
parse_update_response,
1829
)
19-
from pinecone.grpc.utils import parse_fetch_response, parse_query_response
2030

2131

2232
def create_vector(id: str, dimension: int, include_sparse: bool = False) -> Vector:
@@ -104,6 +114,35 @@ def create_query_response(
104114
)
105115

106116

117+
def create_fetch_by_metadata_response(
118+
num_vectors: int, dimension: int, include_sparse: bool = False
119+
) -> FetchByMetadataResponse:
120+
"""Create a FetchByMetadataResponse protobuf message with specified number of vectors."""
121+
vectors = {}
122+
for i in range(num_vectors):
123+
vector = create_vector(f"vec_{i}", dimension, include_sparse)
124+
vectors[f"vec_{i}"] = vector
125+
126+
pagination = Pagination(next="next_token") if num_vectors > 10 else None
127+
128+
return FetchByMetadataResponse(
129+
vectors=vectors,
130+
namespace="test_namespace",
131+
usage=Usage(read_units=num_vectors),
132+
pagination=pagination,
133+
)
134+
135+
136+
def create_upsert_response(upserted_count: int) -> UpsertResponse:
137+
"""Create an UpsertResponse protobuf message."""
138+
return UpsertResponse(upserted_count=upserted_count)
139+
140+
141+
def create_update_response(matched_records: int) -> UpdateResponse:
142+
"""Create an UpdateResponse protobuf message."""
143+
return UpdateResponse(matched_records=matched_records)
144+
145+
107146
class TestFetchResponseParsingPerf:
108147
"""Performance benchmarks for parse_fetch_response."""
109148

@@ -160,3 +199,31 @@ def test_parse_query_response_sparse(self, benchmark, num_matches, dimension):
160199
"""Benchmark parse_query_response with sparse vectors."""
161200
response = create_query_response(num_matches, dimension, include_sparse=True)
162201
benchmark(parse_query_response, response, False, None)
202+
203+
204+
class TestFetchByMetadataResponseParsingPerf:
205+
"""Performance benchmarks for parse_fetch_by_metadata_response."""
206+
207+
@pytest.mark.parametrize("num_vectors,dimension", [(10, 128), (100, 128), (1000, 128)])
208+
def test_parse_fetch_by_metadata_response_dense(self, benchmark, num_vectors, dimension):
209+
"""Benchmark parse_fetch_by_metadata_response with dense vectors."""
210+
response = create_fetch_by_metadata_response(num_vectors, dimension, include_sparse=False)
211+
benchmark(parse_fetch_by_metadata_response, response, None)
212+
213+
214+
class TestUpsertResponseParsingPerf:
215+
"""Performance benchmarks for parse_upsert_response."""
216+
217+
def test_parse_upsert_response(self, benchmark):
218+
"""Benchmark parse_upsert_response."""
219+
response = create_upsert_response(upserted_count=100)
220+
benchmark(parse_upsert_response, response, False, None)
221+
222+
223+
class TestUpdateResponseParsingPerf:
224+
"""Performance benchmarks for parse_update_response."""
225+
226+
def test_parse_update_response(self, benchmark):
227+
"""Benchmark parse_update_response."""
228+
response = create_update_response(matched_records=50)
229+
benchmark(parse_update_response, response, False, None)

0 commit comments

Comments
 (0)