feat(integrations): Add integration for qdrant #3623

Status: Draft · wants to merge 3 commits into base: master
2 changes: 2 additions & 0 deletions sentry_sdk/consts.py
@@ -426,6 +426,8 @@ class OP:
COHERE_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.cohere"
COHERE_EMBEDDINGS_CREATE = "ai.embeddings.create.cohere"
DB = "db"
DB_QDRANT_GRPC = "db.qdrant.grpc"
DB_QDRANT_REST = "db.qdrant.rest"
DB_REDIS = "db.redis"
EVENT_DJANGO = "event.django"
FUNCTION = "function"
40 changes: 40 additions & 0 deletions sentry_sdk/integrations/qdrant/__init__.py
@@ -0,0 +1,40 @@
from sentry_sdk.integrations import DidNotEnable

try:
from qdrant_client.http import ApiClient, AsyncApiClient
import grpc
except ImportError:
raise DidNotEnable("Qdrant client not installed")

from sentry_sdk.integrations import Integration
from sentry_sdk.integrations.qdrant.consts import _IDENTIFIER
from sentry_sdk.integrations.qdrant.qdrant import (
_sync_api_client_send_inner,
_async_api_client_send_inner,
_wrap_channel_sync,
_wrap_channel_async,
)


class QdrantIntegration(Integration):
identifier = _IDENTIFIER

def __init__(self, mute_children_http_spans=True):
# type: (bool) -> None
self.mute_children_http_spans = mute_children_http_spans

@staticmethod
def setup_once():
# type: () -> None

# hooks for the REST client
ApiClient.send_inner = _sync_api_client_send_inner(ApiClient.send_inner)
AsyncApiClient.send_inner = _async_api_client_send_inner(
AsyncApiClient.send_inner
)

# hooks for the gRPC client
grpc.secure_channel = _wrap_channel_sync(grpc.secure_channel)
grpc.insecure_channel = _wrap_channel_sync(grpc.insecure_channel)
grpc.aio.secure_channel = _wrap_channel_async(grpc.aio.secure_channel)
grpc.aio.insecure_channel = _wrap_channel_async(grpc.aio.insecure_channel)
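
Note (illustrative, not part of the diff): enabling the integration would be a standard sentry_sdk.init call; the DSN below is a placeholder and mute_children_http_spans=True is already the default.

import sentry_sdk
from sentry_sdk.integrations.qdrant import QdrantIntegration

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,
    # Drop the near-duplicate httpx spans emitted for Qdrant REST calls
    # (this is the default; pass False to keep them).
    integrations=[QdrantIntegration(mute_children_http_spans=True)],
)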
121 changes: 121 additions & 0 deletions sentry_sdk/integrations/qdrant/consts.py
@@ -0,0 +1,121 @@
from sentry_sdk.integrations.qdrant.path_matching import PathTrie

SPAN_ORIGIN = "auto.db.qdrant"

# created from https://github.com/qdrant/qdrant/blob/master/docs/redoc/v1.11.x/openapi.json
# only used for Qdrant's REST API; gRPC uses different identifiers
_PATH_TO_OPERATION_ID = {
Review comment (Member): I am not sure it is such a good idea to hardcode this dictionary based on something from Qdrant which could change in future Qdrant versions. It would be better to somehow obtain this information from Qdrant at runtime, to maintain compatibility with future versions. (A hedged sketch of deriving this mapping at runtime follows this file's diff.)

"/collections/{collection_name}/shards": {"put": "create_shard_key"},
"/collections/{collection_name}/shards/delete": {"post": "delete_shard_key"},
"/": {"get": "root"},
"/telemetry": {"get": "telemetry"},
"/metrics": {"get": "metrics"},
"/locks": {"post": "post_locks", "get": "get_locks"},
"/healthz": {"get": "healthz"},
"/livez": {"get": "livez"},
"/readyz": {"get": "readyz"},
"/issues": {"get": "get_issues", "delete": "clear_issues"},
"/cluster": {"get": "cluster_status"},
"/cluster/recover": {"post": "recover_current_peer"},
"/cluster/peer/{peer_id}": {"delete": "remove_peer"},
"/collections": {"get": "get_collections"},
"/collections/{collection_name}": {
"get": "get_collection",
"put": "create_collection",
"patch": "update_collection",
"delete": "delete_collection",
},
"/collections/aliases": {"post": "update_aliases"},
"/collections/{collection_name}/index": {"put": "create_field_index"},
"/collections/{collection_name}/exists": {"get": "collection_exists"},
"/collections/{collection_name}/index/{field_name}": {
"delete": "delete_field_index"
},
"/collections/{collection_name}/cluster": {
"get": "collection_cluster_info",
"post": "update_collection_cluster",
},
"/collections/{collection_name}/aliases": {"get": "get_collection_aliases"},
"/aliases": {"get": "get_collections_aliases"},
"/collections/{collection_name}/snapshots/upload": {
"post": "recover_from_uploaded_snapshot"
},
"/collections/{collection_name}/snapshots/recover": {
"put": "recover_from_snapshot"
},
"/collections/{collection_name}/snapshots": {
"get": "list_snapshots",
"post": "create_snapshot",
},
"/collections/{collection_name}/snapshots/{snapshot_name}": {
"delete": "delete_snapshot",
"get": "get_snapshot",
},
"/snapshots": {"get": "list_full_snapshots", "post": "create_full_snapshot"},
"/snapshots/{snapshot_name}": {
"delete": "delete_full_snapshot",
"get": "get_full_snapshot",
},
"/collections/{collection_name}/shards/{shard_id}/snapshots/upload": {
"post": "recover_shard_from_uploaded_snapshot"
},
"/collections/{collection_name}/shards/{shard_id}/snapshots/recover": {
"put": "recover_shard_from_snapshot"
},
"/collections/{collection_name}/shards/{shard_id}/snapshots": {
"get": "list_shard_snapshots",
"post": "create_shard_snapshot",
},
"/collections/{collection_name}/shards/{shard_id}/snapshots/{snapshot_name}": {
"delete": "delete_shard_snapshot",
"get": "get_shard_snapshot",
},
"/collections/{collection_name}/points/{id}": {"get": "get_point"},
"/collections/{collection_name}/points": {
"post": "get_points",
"put": "upsert_points",
},
"/collections/{collection_name}/points/delete": {"post": "delete_points"},
"/collections/{collection_name}/points/vectors": {"put": "update_vectors"},
"/collections/{collection_name}/points/vectors/delete": {"post": "delete_vectors"},
"/collections/{collection_name}/points/payload": {
"post": "set_payload",
"put": "overwrite_payload",
},
"/collections/{collection_name}/points/payload/delete": {"post": "delete_payload"},
"/collections/{collection_name}/points/payload/clear": {"post": "clear_payload"},
"/collections/{collection_name}/points/batch": {"post": "batch_update"},
"/collections/{collection_name}/points/scroll": {"post": "scroll_points"},
"/collections/{collection_name}/points/search": {"post": "search_points"},
"/collections/{collection_name}/points/search/batch": {
"post": "search_batch_points"
},
"/collections/{collection_name}/points/search/groups": {
"post": "search_point_groups"
},
"/collections/{collection_name}/points/recommend": {"post": "recommend_points"},
"/collections/{collection_name}/points/recommend/batch": {
"post": "recommend_batch_points"
},
"/collections/{collection_name}/points/recommend/groups": {
"post": "recommend_point_groups"
},
"/collections/{collection_name}/points/discover": {"post": "discover_points"},
"/collections/{collection_name}/points/discover/batch": {
"post": "discover_batch_points"
},
"/collections/{collection_name}/points/count": {"post": "count_points"},
"/collections/{collection_name}/points/query": {"post": "query_points"},
"/collections/{collection_name}/points/query/batch": {"post": "query_batch_points"},
"/collections/{collection_name}/points/query/groups": {
"post": "query_points_groups"
},
}

_DISALLOWED_PROTO_FIELDS = {"data", "keyword"}

_DISALLOWED_REST_FIELDS = {"nearest", "value"}

_IDENTIFIER = "qdrant"

_qdrant_trie = PathTrie(_PATH_TO_OPERATION_ID)
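
Note (hedged sketch addressing the review comment above, not part of the diff): if the parsed openapi.json is available at runtime, the mapping could be derived from it instead of hardcoded. How the spec is obtained (bundled file, HTTP fetch of the server's spec, etc.) is an assumption left open here.

from typing import Any, Dict

_HTTP_METHODS = {"get", "put", "post", "delete", "patch", "head", "options"}


def build_path_to_operation_id(spec):
    # type: (Dict[str, Any]) -> Dict[str, Dict[str, str]]
    """Derive a path -> {method: operationId} mapping from an OpenAPI document."""
    mapping = {}  # type: Dict[str, Dict[str, str]]
    for path, path_item in spec.get("paths", {}).items():
        methods = {}  # type: Dict[str, str]
        for method, operation in path_item.items():
            if method.lower() in _HTTP_METHODS and isinstance(operation, dict):
                operation_id = operation.get("operationId")
                if operation_id:
                    methods[method.lower()] = operation_id
        if methods:
            mapping[path] = methods
    return mapping


# _qdrant_trie could then be built as PathTrie(build_path_to_operation_id(spec)),
# keeping the hardcoded _PATH_TO_OPERATION_ID as a fallback.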
144 changes: 144 additions & 0 deletions sentry_sdk/integrations/qdrant/path_matching.py
@@ -0,0 +1,144 @@
from typing import Any, Dict, Optional, List


class TrieNode:
Review comment (Member): Why do we need to define a custom data structure here? Is there no way to do this with one of the APIs exposed by Qdrant or with one of the built-in data structures? (An illustrative matching example follows this file's diff.)

def __init__(self, is_placeholder=False):
"""
Initializes a TrieNode.
:param is_placeholder: Indicates if this node represents a placeholder (wildcard).
"""
self.children = {} # type: Dict[str, 'TrieNode']
self.operation_ids = {} # type: Dict[str, str]
self.is_placeholder = is_placeholder # type: bool

@classmethod
def from_dict(cls, data, parent_path=""):
# type: (Dict[str, Any], str) -> 'TrieNode'
"""
Recursively constructs a TrieNode from a nested dictionary.
:param data: Nested dictionary mapping path segments to either nested dictionaries
or dictionaries of HTTP methods to operation IDs.
:param parent_path: The accumulated path from the root to the current node.
:return: Root TrieNode of the constructed trie.
"""
node = cls()
for path, methods in data.items():
segments = PathTrie.split_path(path)
current = node
for segment in segments:
is_placeholder = segment.startswith("{") and segment.endswith("}")
key = "*" if is_placeholder else segment

if key not in current.children:
current.children[key] = TrieNode(is_placeholder=is_placeholder)
current = current.children[key]

if isinstance(methods, dict):
for method, operation_id in methods.items():
current.operation_ids[method.lower()] = operation_id

return node

def to_dict(self, current_path=""):
# type: (str) -> Dict[str, Any]
"""
Serializes the TrieNode and its children back to a nested dictionary.
:param current_path: The accumulated path from the root to the current node.
:return: Nested dictionary representing the trie.
"""
result = {} # type: Dict[str, Any]
if self.operation_ids:
path_key = current_path or "/"
result[path_key] = self.operation_ids.copy()

for segment, child in self.children.items():
# replace the wildcard '*' with a generic placeholder segment so that the
# output of to_dict() can be fed back into from_dict() (the round trip preserves
# structure, but not the original placeholder names).
display_segment = "{placeholder}" if child.is_placeholder else segment
new_path = (
f"{current_path}/{display_segment}"
if current_path
else f"/{display_segment}"
)
child_dict = child.to_dict(new_path)
result.update(child_dict)

return result


class PathTrie:
WILDCARD = "*" # type: str

def __init__(self, data=None):
# type: (Optional[Dict[str, Any]]) -> None
"""
Initializes the PathTrie with optional initial data.
:param data: Optional nested dictionary to initialize the trie.
"""
self.root = TrieNode.from_dict(data or {}) # type: TrieNode

def insert(self, path, method, operation_id):
# type: (str, str, str) -> None
"""
Inserts a path into the trie with its corresponding HTTP method and operation ID.
:param path: The API path (e.g., '/users/{user_id}/posts').
:param method: HTTP method (e.g., 'GET', 'POST').
:param operation_id: The operation identifier associated with the path and method.
"""
current = self.root
segments = self.split_path(path)

for segment in segments:
is_placeholder = self._is_placeholder(segment)
key = self.WILDCARD if is_placeholder else segment

if key not in current.children:
current.children[key] = TrieNode(is_placeholder=is_placeholder)
current = current.children[key]

current.operation_ids[method.lower()] = operation_id

def match(self, path, method):
# type: (str, str) -> Optional[str]
"""
Matches a given path and HTTP method to its corresponding operation ID.
:param path: The API path to match.
:param method: HTTP method to match.
:return: The operation ID if a match is found; otherwise, None.
"""
current = self.root
segments = self.split_path(path)

for segment in segments:
if segment in current.children:
current = current.children[segment]
elif self.WILDCARD in current.children:
current = current.children[self.WILDCARD]
else:
return None

return current.operation_ids.get(method.lower())

def to_dict(self):
# type: () -> Dict[str, Any]
return self.root.to_dict()

@staticmethod
def split_path(path):
# type: (str) -> List[str]
return [segment for segment in path.strip("/").split("/") if segment]

@staticmethod
def _is_placeholder(segment):
# type: (str) -> bool
return segment.startswith("{") and segment.endswith("}")

def __repr__(self):
# type: () -> str
return f"PathTrie({self.to_dict()})"
523 changes: 523 additions & 0 deletions sentry_sdk/integrations/qdrant/qdrant.py
@@ -0,0 +1,523 @@
import json
from contextlib import contextmanager
from decimal import Decimal
from functools import wraps
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from google.protobuf.descriptor import FieldDescriptor
from httpx import Response, Request

import grpc

import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations.httpx import HttpxIntegration
from sentry_sdk.integrations.qdrant.consts import (
_DISALLOWED_REST_FIELDS,
_DISALLOWED_PROTO_FIELDS,
_qdrant_trie,
_IDENTIFIER,
)

# Hack to get new Python features working in older versions
# without introducing a hard dependency on `typing_extensions`
# from: https://stackoverflow.com/a/71944042/300572
# taken from sentry_sdk.integrations.grpc.__init__.py
Review comment (Member, on lines +23 to +26): Not needed, see comment below. Suggested change: remove these four comment lines.

if TYPE_CHECKING:
from typing import (
Any,
Callable,
Awaitable,
ParamSpec,
Optional,
Sequence,
Generator,
)
from grpc import Channel, ClientCallDetails
from grpc.aio import UnaryUnaryCall
from grpc.aio import Channel as AsyncChannel
from google.protobuf.message import Message
from sentry_sdk.tracing import Span, Transaction, _SpanRecorder
from sentry_sdk.integrations.qdrant import QdrantIntegration
Review comment (Member, on lines +37 to +42): Are these only being used in type comments? If not, they should be imported outside the if TYPE_CHECKING block.

else:
# Fake ParamSpec
class ParamSpec:
def __init__(self, _):
self.args = None
self.kwargs = None

# Callable[anything] will return None
class _Callable:
def __getitem__(self, _):
return None

# Make instances
Callable = _Callable()
Review comment (Member, on lines +43 to +56): You can (and should) remove this, because you are using type comments. Suggested change: remove the else block with the fake ParamSpec and Callable definitions.


P = ParamSpec("P")


def _remove_httpx_span(span, operation_id):
# type: (Span, str) -> None
current_transaction = span.containing_transaction # type: Optional[Transaction]
if not current_transaction:
return

span_recorder = current_transaction._span_recorder # type: Optional[_SpanRecorder]
if not span_recorder:
return

try:
current_span_index = span_recorder.spans.index(span)
except ValueError:
# span not found in its own recorder; nothing to do
return

next_span_index = current_span_index + 1
if next_span_index >= len(span_recorder.spans):
return

next_span = span_recorder.spans[next_span_index] # type: Span

# check if the next span is an HTTPX client span
if next_span.op != OP.HTTP_CLIENT or next_span.origin != HttpxIntegration.origin:
return

httpx_span_description = next_span.description # type: Optional[str]
if not httpx_span_description:
return

try:
httpx_method, httpx_url = httpx_span_description.split(" ", 1)
except ValueError:
# unexpected span name format
return

parsed_url = urlparse(httpx_url)
httpx_path = parsed_url.path

# just to be *really* sure that we don't accidentally delete an unrelated span
httpx_operation_id = _qdrant_trie.match(httpx_path, httpx_method)

if httpx_operation_id == operation_id:
span_recorder.spans.pop(next_span_index)


@contextmanager
def _prepare_span(request, mute_http_child_span):
# type: (Request, bool) -> Generator[Span, None, None]
operation_id = _qdrant_trie.match(request.url.path, request.method)
payload = json.loads(request.content) if request.content else {}
payload["operation_id"] = operation_id

name = json.dumps(_strip_dict(payload, _DISALLOWED_REST_FIELDS), default=str)

span = sentry_sdk.start_span(
op=OP.DB_QDRANT_REST,
name=name,
origin=_IDENTIFIER,
)
span.set_data("db.api", "REST")
span.set_data(SPANDATA.DB_SYSTEM, "qdrant")
span.set_data(SPANDATA.DB_OPERATION, operation_id)

collection_name = getattr(request, "collection_name", None)
if collection_name:
span.set_data("db.qdrant.collection", collection_name)

yield span

# The HttpxIntegration also captures every REST call, producing near-duplicate spans with less information.
# We remove the span created by httpx by default; this can be disabled via the integration's mute_children_http_spans option.
if mute_http_child_span:
_remove_httpx_span(span, operation_id)

span.finish()


def _get_integration():
# type: () -> Optional[QdrantIntegration]
from sentry_sdk.integrations.qdrant import QdrantIntegration

return sentry_sdk.get_client().get_integration(QdrantIntegration)


def _sync_api_client_send_inner(f):
# type: (Callable[P, Response]) -> Callable[P, Response]
@wraps(f)
def wrapper(*args, **kwargs):
# type: (P.args, P.kwargs) -> Response
integration = _get_integration() # type: Optional[QdrantIntegration]
if integration is None or not (len(args) >= 2 and isinstance(args[1], Request)):
return f(*args, **kwargs)

request = args[1] # type: Request
with _prepare_span(request, integration.mute_children_http_spans):
return f(*args, **kwargs)

return wrapper


def _async_api_client_send_inner(f):
# type: (Callable[P, Awaitable[Response]]) -> Callable[P, Awaitable[Response]]
@wraps(f)
async def wrapper(*args, **kwargs):
# type: (P.args, P.kwargs) -> Response
integration = _get_integration() # type: Optional[QdrantIntegration]
if integration is None or not (len(args) >= 2 and isinstance(args[1], Request)):
return await f(*args, **kwargs)

request = args[1] # type: Request
with _prepare_span(request, integration.mute_children_http_spans):
return await f(*args, **kwargs)

return wrapper


# taken from grpc integration
Review comment (Member): Why not just use the grpc integration (or use this code from the grpc integration) directly?

def _wrap_channel_sync(f):
# type: (Callable[P, Channel]) -> Callable[P, Channel]

@wraps(f)
def patched_channel(*args, **kwargs):
# type: (P.args, P.kwargs) -> Channel
channel = f(*args, **kwargs)

if not ClientInterceptor._is_intercepted:
ClientInterceptor._is_intercepted = True
return grpc.intercept_channel(channel, ClientInterceptor())
else:
return channel

return patched_channel


def _wrap_channel_async(func):
# type: (Callable[P, AsyncChannel]) -> Callable[P, AsyncChannel]
"Wrapper for asynchronous secure and insecure channel."

@wraps(func)
def patched_channel(*args, interceptors, **kwargs):
# type: (P.args, Optional[Sequence[grpc.aio.ClientInterceptor]], P.kwargs) -> Channel
sentry_interceptors = [
AsyncClientInterceptor(),
]
interceptors = [*sentry_interceptors, *(interceptors or [])]
return func(*args, interceptors=interceptors, **kwargs)

return patched_channel


class GenericClientInterceptor:
_is_intercepted = False

@staticmethod
def _update_client_call_details_metadata_from_scope(client_call_details):
# type: (ClientCallDetails) -> ClientCallDetails
"""
Updates the metadata of the client call details by appending
trace propagation headers from the current Sentry scope.
"""
metadata = (
list(client_call_details.metadata) if client_call_details.metadata else []
)

# append Sentry's trace propagation headers
for (
key,
value,
) in sentry_sdk.get_current_scope().iter_trace_propagation_headers():
metadata.append((key, value))

updated_call_details = grpc._interceptor._ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
compression=client_call_details.compression,
)

return updated_call_details

@staticmethod
def _should_intercept(method):
# type: (str) -> bool
"""
Determines whether the interceptor should process the given method.
"""
if not method.startswith("/qdrant."):
return False

from sentry_sdk.integrations.qdrant import QdrantIntegration

return sentry_sdk.get_client().get_integration(QdrantIntegration) is not None

@staticmethod
def _prepare_tags(method, request):
# type: (str, Message) -> dict
"""
Prepares the tags for the Sentry span based on the method and request.
"""
# qdrant uses this prefix for all its methods
operation = method[len("/qdrant.") :]

tags = {
SPANDATA.DB_SYSTEM: "qdrant",
SPANDATA.DB_OPERATION: operation,
}

collection_name = getattr(request, "collection_name", None)
if collection_name:
tags["db.qdrant.collection"] = collection_name

return tags

@contextmanager
def _start_span(self, request, tags):
# type: (Message, dict) -> Generator[Span, None, None]
"""
Starts a new Sentry span for the gRPC call and sets the provided tags.
"""
span = sentry_sdk.start_span(
op=OP.DB_QDRANT_GRPC,
name=json.dumps(
_strip_dict(_protobuf_to_dict(request), _DISALLOWED_PROTO_FIELDS),
default=str,
),
origin=_IDENTIFIER,
)
try:
for tag, value in tags.items():
span.set_data(tag, value)
span.set_tag(tag, value)
yield span
finally:
span.finish()


class ClientInterceptor(GenericClientInterceptor, grpc.UnaryUnaryClientInterceptor):
def intercept_unary_unary(self, continuation, client_call_details, request):
# type: (Callable[[ClientCallDetails, Message], UnaryUnaryCall], ClientCallDetails, Message) -> UnaryUnaryCall
"""
Intercepts synchronous unary-unary gRPC calls to add Sentry tracing.
"""
method = client_call_details.method

if not self._should_intercept(method):
return continuation(client_call_details, request)

tags = self._prepare_tags(method, request)
with self._start_span(request, tags) as span:
span.set_data("db.api", "gRPC")
updated_call_details = self._update_client_call_details_metadata_from_scope(
client_call_details
)

response = continuation(updated_call_details, request)
span.set_data("code", response.code().name)

return response


class AsyncClientInterceptor(
GenericClientInterceptor, grpc.aio.UnaryUnaryClientInterceptor
):
async def intercept_unary_unary(self, continuation, client_call_details, request):
# type: (Callable[[ClientCallDetails, Message], UnaryUnaryCall], ClientCallDetails, Message) -> UnaryUnaryCall
"""
Intercepts asynchronous unary-unary gRPC calls to add Sentry tracing.
"""
method = client_call_details.method

if not self._should_intercept(method):
return await continuation(client_call_details, request)

tags = self._prepare_tags(method, request)

with self._start_span(request, tags) as span:
span.set_data("db.api", "gRPC")
updated_call_details = self._update_client_call_details_metadata_from_scope(
client_call_details
)

response = await continuation(updated_call_details, request)

status_code = await response.code()
span.set_data("code", status_code.name)

return response


def _protobuf_to_dict(message, prefix=""):
Review comment (Member): Does GRPC not provide a built-in way to do this? (A sketch using protobuf's MessageToDict follows this function.)

# type: (Message, str) -> dict
"""
Recursively converts a protobuf message to a dictionary, excluding unset fields.
Args:
message (Message): The protobuf message instance.
prefix (str): The prefix for field names to indicate nesting.
Returns:
dict: A dictionary representation of the protobuf message with only set fields.
"""
result = {}

for field in message.DESCRIPTOR.fields:
field_name = field.name

full_field_name = f"{prefix}.{field_name}" if prefix else field_name

# determine if the field is set
if field.type == FieldDescriptor.TYPE_MESSAGE:
if field.label == FieldDescriptor.LABEL_REPEATED:
is_set = len(getattr(message, field_name)) > 0
elif (
field.message_type.has_options
and field.message_type.GetOptions().map_entry
):
is_set = len(getattr(message, field_name)) > 0
else:
is_set = message.HasField(field_name)
else:
if field.label == FieldDescriptor.LABEL_REPEATED:
is_set = len(getattr(message, field_name)) > 0
else:
# for scalar fields, check presence if possible
try:
is_set = message.HasField(field_name)
except ValueError:
# HasField not available (e.g., proto3 without optional)
# fallback: consider non-default values as set
default_value = field.default_value
value = getattr(message, field_name)
is_set = value != default_value

if not is_set:
# field is either not set or has a default value
continue

value = getattr(message, field_name)

if field.type == FieldDescriptor.TYPE_MESSAGE:
if field.label == FieldDescriptor.LABEL_REPEATED:
# repeated message fields: list of dicts
result[field_name] = [
_protobuf_to_dict(item, prefix=full_field_name) for item in value
]
elif (
field.message_type.has_options
and field.message_type.GetOptions().map_entry
):
# map dict fields
result[field_name] = {key: val for key, val in value.items()}
else:
# single nested message
result[field_name] = _protobuf_to_dict(value, prefix=full_field_name)
elif field.label == FieldDescriptor.LABEL_REPEATED:
# repeated scalar fields
result[field_name] = list(value)
else:
# scalar field
result[field_name] = value

return result
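
Note (hedged, addressing the review question above): protobuf ships a converter, google.protobuf.json_format.MessageToDict, which omits default-valued fields by default and could replace most of this walk; its output differs in some details (e.g. 64-bit integers and bytes are rendered as strings per the proto3 JSON mapping). A minimal sketch:

from google.protobuf.json_format import MessageToDict


def _protobuf_to_dict_builtin(message):
    # type: (Message) -> dict
    # preserving_proto_field_name keeps snake_case names instead of lowerCamelCase.
    return MessageToDict(message, preserving_proto_field_name=True)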


def _get_forbidden_field_placeholder(value, depth=0, max_depth=5):
# type: (Any, int, int) -> str
"""
Generates a placeholder string based on the type of the input value.
Args:
value (Any): The value to generate a placeholder for.
depth (int): Current recursion depth.
max_depth (int): Maximum recursion depth to prevent excessive recursion.
Returns:
str: A placeholder string representing the type of the input value.
"""
if depth > max_depth:
return "..."

if isinstance(value, bool) or isinstance(value, str):
return "%s"
elif isinstance(value, int):
return "%d"
elif isinstance(value, (float, Decimal)):
return "%f"
elif isinstance(value, bytes):
return "b'...'"
elif value is None:
return "null"
elif isinstance(value, list):
if not value:
return "[]"
# handle heterogeneous lists by representing each element
placeholders = [
_get_forbidden_field_placeholder(item, depth + 1, max_depth)
for item in value
]
max_items = 3
if len(placeholders) > max_items:
placeholders = placeholders[:max_items] + ["..."]
return f"[{', '.join(placeholders)}]"
elif isinstance(value, tuple):
if not value:
return "()"
placeholders = [
_get_forbidden_field_placeholder(item, depth + 1, max_depth)
for item in value
]
max_items = 3
if len(placeholders) > max_items:
placeholders = placeholders[:max_items] + ["..."]
return f"({', '.join(placeholders)})"
elif isinstance(value, set):
if not value:
return "set()"
placeholders = [
_get_forbidden_field_placeholder(item, depth + 1, max_depth)
for item in sorted(value, key=lambda x: str(x))
]
max_items = 3
if len(placeholders) > max_items:
placeholders = placeholders[:max_items] + ["..."]
return f"{{{', '.join(placeholders)}}}"
elif isinstance(value, dict):
if not value:
return "{}"
# represent keys and values
placeholders = []
max_items = 3
for i, (k, v) in enumerate(value.items()):
if i >= max_items:
placeholders.append("...")
break
key_placeholder = _get_forbidden_field_placeholder(k, depth + 1, max_depth)
value_placeholder = _get_forbidden_field_placeholder(
v, depth + 1, max_depth
)
placeholders.append(f"{key_placeholder}: {value_placeholder}")
return f"{{{', '.join(placeholders)}}}"
elif isinstance(value, complex):
return "(%f+%fj)" % (value.real, value.imag)
else:
# just use the class name at this point
return f"<{type(value).__name__}>"


def _strip_dict(data, disallowed_fields):
# type: (dict, set) -> dict
for k, v in data.items():
if isinstance(v, dict):
data[k] = _strip_dict(v, disallowed_fields)
elif k in disallowed_fields:
data[k] = _get_forbidden_field_placeholder(v)
elif isinstance(v, list):
data[k] = [
_strip_dict(item, disallowed_fields) if isinstance(item, dict) else item
for item in v
]
return data
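
Note (illustrative, not part of the diff): how _strip_dict scrubs a REST payload before it becomes the span name; the payload shape and values below are made up.

from sentry_sdk.integrations.qdrant.consts import _DISALLOWED_REST_FIELDS
from sentry_sdk.integrations.qdrant.qdrant import _strip_dict

payload = {
    "query": {"nearest": [0.12, 0.34, 0.56]},
    "limit": 10,
    "filter": {"must": [{"key": "city", "match": {"value": "Berlin"}}]},
}

scrubbed = _strip_dict(payload, _DISALLOWED_REST_FIELDS)
# {'query': {'nearest': '[%f, %f, %f]'},
#  'limit': 10,
#  'filter': {'must': [{'key': 'city', 'match': {'value': '%s'}}]}}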