Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kafka_consumer/changelog.d/22020.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support Protobuf messages with schema registry
188 changes: 135 additions & 53 deletions kafka_consumer/datadog_checks/kafka_consumer/kafka_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,49 +630,92 @@ def deserialize_message(
return decoded_value, value_schema_id, None, None


def _read_varint(data):
shift = 0
result = 0
bytes_read = 0

for byte in data:
bytes_read += 1
result |= (byte & 0x7F) << shift
if (byte & 0x80) == 0:
return result, bytes_read
shift += 7

raise ValueError("Incomplete varint")


def _read_protobuf_message_indices(payload):
    """Parse the Confluent Protobuf message-indices array from ``payload``.

    In the Confluent wire format the bytes following the 4-byte schema ID are:
    [indices_count:varint][index:varint ...]. The index path selects the
    target message type within the .proto file — [0] is the first top-level
    message, [1] the second, [0, 0] the first nested type of the first
    message, and so on.

    Args:
        payload: message bytes positioned just past the schema ID

    Returns:
        tuple: (list of message indices, bytes remaining after the array)
    """
    count, consumed = _read_varint(payload)
    remaining = payload[consumed:]

    indices = []
    for _ in range(count):
        index, consumed = _read_varint(remaining)
        remaining = remaining[consumed:]
        indices.append(index)

    return indices, remaining


def _deserialize_bytes_maybe_schema_registry(message, message_format, schema, uses_schema_registry):
    """Deserialize ``message``, resolving whether it uses schema registry framing.

    Args:
        message: raw message bytes from Kafka
        message_format: 'protobuf', 'avro', or anything else (treated as JSON)
        schema: schema object built for ``message_format``
        uses_schema_registry: True when the check is explicitly configured for
            Confluent schema registry framing; False enables fallback probing

    Returns:
        Tuple of (decoded message, schema_id or None)
    """
    if not message:
        return "", None
    if uses_schema_registry:
        # When explicitly configured, go straight to schema registry format
        return _deserialize_bytes(message, message_format, schema, True)
    else:
        # Fallback behavior: try without schema registry format first; if plain
        # decoding fails, retry assuming the magic-byte + schema ID framing.
        try:
            return _deserialize_bytes(message, message_format, schema, False)
        except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
            return _deserialize_bytes(message, message_format, schema, True)


def _deserialize_bytes(message, message_format, schema, uses_schema_registry):
    """Deserialize a message from Kafka.

    Args:
        message: Raw message bytes from Kafka
        message_format: Format of the message (protobuf, avro, json, etc.)
        schema: Schema object (type depends on message_format)
        uses_schema_registry: Whether message uses schema registry format

    Returns:
        Tuple of (decoded_message, schema_id) where schema_id is None if not
        using schema registry

    Raises:
        ValueError: when ``uses_schema_registry`` is True but the message does
            not start with the magic byte + 4-byte schema ID header
    """
    if not message:
        return "", None

    schema_id = None
    if uses_schema_registry:
        # Confluent framing: 1 magic byte (0x00) + 4-byte big-endian schema ID
        if len(message) < 5 or message[0] != SCHEMA_REGISTRY_MAGIC_BYTE:
            msg_hex = message[:5].hex() if len(message) >= 5 else message.hex()
            raise ValueError(
                f"Expected schema registry format (magic byte 0x00 + 4-byte schema ID), "
                f"but message is too short or has wrong magic byte: {msg_hex}"
            )
        schema_id = int.from_bytes(message[1:5], 'big')
        message = message[5:]  # Skip the magic byte and schema ID bytes

    if message_format == 'protobuf':
        # Protobuf additionally needs to know whether Confluent message
        # indices precede the payload.
        return _deserialize_protobuf(message, schema, uses_schema_registry), schema_id
    elif message_format == 'avro':
        return _deserialize_avro(message, schema), schema_id
    else:
        return _deserialize_json(message), schema_id


def _deserialize_json(message):
Expand All @@ -681,10 +724,58 @@ def _deserialize_json(message):
return decoded


def _deserialize_protobuf(message, schema):
"""Deserialize a Protobuf message using google.protobuf with strict validation."""
def _get_protobuf_message_class(schema_info, message_indices):
    """Resolve the protobuf message class selected by ``message_indices``.

    Args:
        schema_info: tuple of (descriptor_pool, file_descriptor_set)
        message_indices: index path into the schema's message types
            (e.g. [0] first top-level message, [2, 0] first nested type of
            the third message)

    Returns:
        Dynamically generated message class for the selected type
    """
    pool, descriptor_set = schema_info

    # NOTE(review): only the first file in the descriptor set is considered —
    # confirm schemas with imports/multiple files are not expected here.
    fd = descriptor_set.file[0]

    # Walk the index path: first index picks the top-level message, any
    # further indices descend into nested message types.
    descriptor_proto = fd.message_type[message_indices[0]]
    name_path = [descriptor_proto.name]
    for nested_index in message_indices[1:]:
        descriptor_proto = descriptor_proto.nested_type[nested_index]
        name_path.append(descriptor_proto.name)

    # Fully qualified name is package-prefixed when a package is declared.
    qualified_parts = ([fd.package] if fd.package else []) + name_path
    full_name = '.'.join(qualified_parts)

    message_descriptor = pool.FindMessageTypeByName(full_name)
    return message_factory.GetMessageClass(message_descriptor)


def _deserialize_protobuf(message, schema_info, uses_schema_registry):
"""Deserialize a Protobuf message using google.protobuf with strict validation.

Args:
message: Raw protobuf bytes
schema_info: Tuple of (descriptor_pool, file_descriptor_set) from build_protobuf_schema
uses_schema_registry: Whether to extract Confluent message indices from the message
"""
try:
bytes_consumed = schema.ParseFromString(message)
if uses_schema_registry:
message_indices, message = _read_protobuf_message_indices(message)
else:
message_indices = [0]

message_class = _get_protobuf_message_class(schema_info, message_indices)
schema_instance = message_class()

bytes_consumed = schema_instance.ParseFromString(message)

# Check if all bytes were consumed (strict validation)
if bytes_consumed != len(message):
Expand All @@ -693,7 +784,7 @@ def _deserialize_protobuf(message, schema):
f"Read {bytes_consumed} bytes, but message has {len(message)} bytes. "
)

return MessageToJson(schema)
return MessageToJson(schema_instance)
except Exception as e:
raise ValueError(f"Failed to deserialize Protobuf message: {e}")

Expand Down Expand Up @@ -740,6 +831,17 @@ def build_avro_schema(schema_str):


def build_protobuf_schema(schema_str):
"""Build a Protobuf schema from a base64-encoded FileDescriptorSet.

Returns a tuple of (descriptor_pool, file_descriptor_set) that can be used
to dynamically select and instantiate message types based on message indices.

Args:
schema_str: Base64-encoded FileDescriptorSet

Returns:
tuple: (DescriptorPool, FileDescriptorSet)
"""
# schema is encoded in base64, decode it before passing it to ParseFromString
schema_str = base64.b64decode(schema_str)
descriptor_set = descriptor_pb2.FileDescriptorSet()
Expand All @@ -750,24 +852,4 @@ def build_protobuf_schema(schema_str):
for fd_proto in descriptor_set.file:
pool.Add(fd_proto)

# Pick the first message type from the first file descriptor
first_fd = descriptor_set.file[0]
# The file descriptor contains a list of message types (DescriptorProto)
first_message_proto = first_fd.message_type[0]

# The fully qualified name includes the package name + message name
package = first_fd.package
message_name = first_message_proto.name
if package:
full_name = f"{package}.{message_name}"
else:
full_name = message_name
# # Get the message descriptor
message_descriptor = pool.FindMessageTypeByName(full_name)
# Create a dynamic message class
schema = message_factory.GetMessageClass(message_descriptor)()

if schema is None:
raise ValueError("Protobuf schema cannot be None")

return schema
return (pool, descriptor_set)
56 changes: 48 additions & 8 deletions kafka_consumer/tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from datadog_checks.kafka_consumer.kafka_consumer import (
DATA_STREAMS_MESSAGES_CACHE_KEY,
_get_interpolated_timestamp,
_get_protobuf_message_class,
build_avro_schema,
build_protobuf_schema,
build_schema,
Expand Down Expand Up @@ -902,8 +903,12 @@ def test_schema_registry_explicit_configuration():
assert result == (None, None, None, None), "Protobuf should fail when uses_schema_registry=True but no SR format"

# Valid Protobuf message WITH schema registry format
# Confluent Protobuf wire format:
# [magic_byte][schema_id:4bytes][array_length:varint][index:varint][protobuf_payload]
protobuf_message_with_sr = (
b'\x00\x00\x00\x01\x5e'
b'\x00\x00\x00\x01\x5e' # magic byte (0x00) + schema ID 350 (0x0000015e)
b'\x01' # message indices array length = 1
b'\x00' # message index = 0
b'\x08\xe8\xba\xb2\xeb\xd1\x9c\x02\x12\x1b\x54\x68\x65\x20\x47\x6f\x20\x50\x72\x6f\x67\x72\x61\x6d\x6d\x69\x6e\x67\x20\x4c\x61\x6e\x67\x75\x61\x67\x65'
b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e'
)
Expand Down Expand Up @@ -934,6 +939,40 @@ def test_schema_registry_explicit_configuration():
assert result[3] is None, "Key schema ID should be None when key fails"


def test_protobuf_message_indices_with_schema_registry():
    """Test Confluent Protobuf wire format with different message indices."""
    key = b'{"test": "key"}'

    # Base64 FileDescriptorSet containing several message types:
    # message Book { int64 isbn = 1; string title = 2; }
    # message Author { string name = 1; int32 age = 2; }
    # message Library { message Section { string name = 1; } string name = 1; }
    protobuf_schema = (
        'CpMBCgxzY2hlbWEucHJvdG8SC2NvbS5leGFtcGxlIh8KBEJvb2sSCgoEaXNibhgBKAMSCwoFdGl0bGUY'
        'AigJIh8KBkF1dGhvchIKCgRuYW1lGAEoCRIJCgNhZ2UYAigFIiwKB0xpYnJhcnkSCgoEbmFtZRgBKAka'
        'FQoHU2VjdGlvbhIKCgRuYW1lGAEoCWIGcHJvdG8z'
    )
    parsed_schema = build_schema('protobuf', protobuf_schema)

    # magic byte (0x00) + schema ID 350 (0x0000015e)
    sr_header = b'\x00\x00\x00\x01\x5e'

    # (indices array bytes, protobuf payload hex, substrings expected in JSON)
    cases = [
        # [0] -> Book
        (b'\x01\x00', '08e80712095465737420426f6f6b', ['Test Book']),
        # [1] -> Author
        (b'\x01\x01', '0a0a4a616e6520536d697468101e', ['Jane Smith', '30']),
        # [2, 0] -> nested Library.Section
        (b'\x02\x02\x00', '0a0746696374696f6e', ['Fiction']),
    ]
    for index_bytes, payload_hex, expected_fragments in cases:
        raw = sr_header + index_bytes + bytes.fromhex(payload_hex)
        result = deserialize_message(
            MockedMessage(raw, key), 'protobuf', parsed_schema, True, 'json', '', False
        )
        assert result[0]
        for fragment in expected_fragments:
            assert fragment in result[0]


def mocked_time():
return 400

Expand Down Expand Up @@ -1200,11 +1239,11 @@ def test_build_schema():
'EhQKBXRpdGxlGAIgASgJUgV0aXRsZRIWCgZhdXRob3IYAyABKAlSBmF1dGhvcmIGcHJvdG8z'
)
protobuf_result = build_schema('protobuf', valid_protobuf_schema)
assert protobuf_result is not None
# The result should be a protobuf message class instance
assert hasattr(protobuf_result, 'isbn')
assert hasattr(protobuf_result, 'title')
assert hasattr(protobuf_result, 'author')
message_class = _get_protobuf_message_class(protobuf_result, [0])
message_instance = message_class()
assert hasattr(message_instance, 'isbn')
assert hasattr(message_instance, 'title')
assert hasattr(message_instance, 'author')

# Test unknown format
assert build_schema('unknown_format', 'some_schema') is None
Expand Down Expand Up @@ -1232,14 +1271,15 @@ def test_build_schema_error_cases():
with pytest.raises(DecodeError): # Will be a protobuf DecodeError
build_schema('protobuf', 'SGVsbG8gV29ybGQ=') # "Hello World" in base64

# Valid base64 but empty schema (should cause IndexError)
# Valid base64 but empty schema - should fail when trying to access message types
# Create a minimal but empty FileDescriptorSet
empty_descriptor = descriptor_pb2.FileDescriptorSet()
empty_descriptor_bytes = empty_descriptor.SerializeToString()
empty_descriptor_b64 = base64.b64encode(empty_descriptor_bytes).decode('utf-8')

result = build_schema('protobuf', empty_descriptor_b64)
with pytest.raises(IndexError): # Should fail when trying to access file[0]
build_schema('protobuf', empty_descriptor_b64)
_get_protobuf_message_class(result, [0])


def test_build_schema_none_handling():
Expand Down
Loading