
Commit 44a9723

kafka_consumer: Support Protobuf messages with schema registry (#22020)

* kafka_consumer: Support Protobuf messages with schema registry
* Update kafka_consumer/changelog.d/22020.fixed

Co-authored-by: Steven Yuen <[email protected]>

1 parent 31e492b commit 44a9723

3 files changed, +184 -61 lines changed
kafka_consumer/changelog.d/22020.fixed

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Support Protobuf messages with schema registry

kafka_consumer/datadog_checks/kafka_consumer/kafka_consumer.py

Lines changed: 135 additions & 53 deletions
@@ -630,49 +630,92 @@ def deserialize_message(
     return decoded_value, value_schema_id, None, None
 
 
+def _read_varint(data):
+    shift = 0
+    result = 0
+    bytes_read = 0
+
+    for byte in data:
+        bytes_read += 1
+        result |= (byte & 0x7F) << shift
+        if (byte & 0x80) == 0:
+            return result, bytes_read
+        shift += 7
+
+    raise ValueError("Incomplete varint")
+
+
+def _read_protobuf_message_indices(payload):
+    """
+    Read the Confluent Protobuf message indices array.
+
+    The Confluent Protobuf wire format includes message indices after the schema ID:
+    [message_indices_length:varint][message_indices:varint...]
+
+    The indices indicate which message type to use from the .proto schema.
+    For example, [0] = first message, [1] = second message, [0, 0] = nested message.
+
+    Args:
+        payload: bytes after the schema ID
+
+    Returns:
+        tuple: (message_indices list, remaining payload bytes)
+    """
+    array_len, bytes_read = _read_varint(payload)
+    payload = payload[bytes_read:]
+
+    indices = []
+    for _ in range(array_len):
+        index, bytes_read = _read_varint(payload)
+        indices.append(index)
+        payload = payload[bytes_read:]
+
+    return indices, payload
+
+
 def _deserialize_bytes_maybe_schema_registry(message, message_format, schema, uses_schema_registry):
     if not message:
         return "", None
     if uses_schema_registry:
-        # When explicitly configured, go straight to schema registry format
-        if len(message) < 5 or message[0] != SCHEMA_REGISTRY_MAGIC_BYTE:
-            msg_hex = message[:5].hex() if len(message) >= 5 else message.hex()
-            raise ValueError(
-                f"Expected schema registry format (magic byte 0x00 + 4-byte schema ID), "
-                f"but message is too short or has wrong magic byte: {msg_hex}"
-            )
-        schema_id = int.from_bytes(message[1:5], 'big')
-        message = message[5:]  # Skip the magic byte and schema ID bytes
-        return _deserialize_bytes(message, message_format, schema), schema_id
+        return _deserialize_bytes(message, message_format, schema, True)
     else:
         # Fallback behavior: try without schema registry format first, then with it
         try:
-            return _deserialize_bytes(message, message_format, schema), None
-        except (UnicodeDecodeError, json.JSONDecodeError, ValueError) as e:
-            # If the message is not valid, it might be a schema registry message, that is prefixed
-            # with a magic byte and a schema ID.
-            if len(message) < 5 or message[0] != SCHEMA_REGISTRY_MAGIC_BYTE:
-                raise e
-            schema_id = int.from_bytes(message[1:5], 'big')
-            message = message[5:]  # Skip the magic byte and schema ID bytes
-            return _deserialize_bytes(message, message_format, schema), schema_id
-
-
-def _deserialize_bytes(message, message_format, schema):
-    """Deserialize a message from Kafka. Supports JSON format.
+            return _deserialize_bytes(message, message_format, schema, False)
+        except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
+            return _deserialize_bytes(message, message_format, schema, True)
+
+
+def _deserialize_bytes(message, message_format, schema, uses_schema_registry):
+    """Deserialize a message from Kafka.
     Args:
         message: Raw message bytes from Kafka
+        message_format: Format of the message (protobuf, avro, json, etc.)
+        schema: Schema object (type depends on message_format)
+        uses_schema_registry: Whether message uses schema registry format
     Returns:
-        Decoded message as a string
+        Tuple of (decoded_message, schema_id) where schema_id is None if not using schema registry
     """
     if not message:
-        return ""
+        return "", None
+
+    schema_id = None
+    if uses_schema_registry:
+        if len(message) < 5 or message[0] != SCHEMA_REGISTRY_MAGIC_BYTE:
+            msg_hex = message[:5].hex() if len(message) >= 5 else message.hex()
+            raise ValueError(
+                f"Expected schema registry format (magic byte 0x00 + 4-byte schema ID), "
+                f"but message is too short or has wrong magic byte: {msg_hex}"
+            )
+        schema_id = int.from_bytes(message[1:5], 'big')
+        message = message[5:]
+
     if message_format == 'protobuf':
-        return _deserialize_protobuf(message, schema)
+        return _deserialize_protobuf(message, schema, uses_schema_registry), schema_id
     elif message_format == 'avro':
-        return _deserialize_avro(message, schema)
+        return _deserialize_avro(message, schema), schema_id
     else:
-        return _deserialize_json(message)
+        return _deserialize_json(message), schema_id
 
 
 def _deserialize_json(message):
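
Note: as an illustration of the framing that the new helpers above parse, the header can be traced by hand. This is a sketch, not part of the commit; the index values and the "Fiction" payload are borrowed from the unit tests added further down.

    # Bytes following the 4-byte schema ID: indices array length 2, indices 2 and 0,
    # then an ordinary protobuf payload (field 1 = "Fiction").
    payload = b'\x02\x02\x00' + b'\x0a\x07Fiction'

    indices, remaining = _read_protobuf_message_indices(payload)
    assert indices == [2, 0]
    assert remaining == b'\x0a\x07Fiction'

    # Varints keep the low 7 bits of each byte; a set high bit means another byte follows,
    # so 0x96 0x01 decodes to 150 after consuming 2 bytes.
    assert _read_varint(b'\x96\x01') == (150, 2)
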
@@ -681,10 +724,58 @@ def _deserialize_json(message):
     return decoded
 
 
-def _deserialize_protobuf(message, schema):
-    """Deserialize a Protobuf message using google.protobuf with strict validation."""
+def _get_protobuf_message_class(schema_info, message_indices):
+    """Get the protobuf message class based on schema info and message indices.
+
+    Args:
+        schema_info: Tuple of (descriptor_pool, file_descriptor_set)
+        message_indices: List of indices (e.g., [0], [1], [2, 0] for nested)
+
+    Returns:
+        Message class for the specified type
+    """
+    pool, descriptor_set = schema_info
+
+    # First index is the message type in the file
+    file_descriptor = descriptor_set.file[0]
+    message_descriptor_proto = file_descriptor.message_type[message_indices[0]]
+
+    package = file_descriptor.package
+    name_parts = [message_descriptor_proto.name]
+
+    # Handle nested messages if there are more indices
+    current_proto = message_descriptor_proto
+    for idx in message_indices[1:]:
+        current_proto = current_proto.nested_type[idx]
+        name_parts.append(current_proto.name)
+
+    if package:
+        full_name = f"{package}.{'.'.join(name_parts)}"
+    else:
+        full_name = '.'.join(name_parts)
+
+    message_descriptor = pool.FindMessageTypeByName(full_name)
+    return message_factory.GetMessageClass(message_descriptor)
+
+
+def _deserialize_protobuf(message, schema_info, uses_schema_registry):
+    """Deserialize a Protobuf message using google.protobuf with strict validation.
+
+    Args:
+        message: Raw protobuf bytes
+        schema_info: Tuple of (descriptor_pool, file_descriptor_set) from build_protobuf_schema
+        uses_schema_registry: Whether to extract Confluent message indices from the message
+    """
     try:
-        bytes_consumed = schema.ParseFromString(message)
+        if uses_schema_registry:
+            message_indices, message = _read_protobuf_message_indices(message)
+        else:
+            message_indices = [0]
+
+        message_class = _get_protobuf_message_class(schema_info, message_indices)
+        schema_instance = message_class()
+
+        bytes_consumed = schema_instance.ParseFromString(message)
 
         # Check if all bytes were consumed (strict validation)
         if bytes_consumed != len(message):
@@ -693,7 +784,7 @@ def _deserialize_protobuf(message, schema):
             f"Read {bytes_consumed} bytes, but message has {len(message)} bytes. "
         )
 
-        return MessageToJson(schema)
+        return MessageToJson(schema_instance)
     except Exception as e:
         raise ValueError(f"Failed to deserialize Protobuf message: {e}")
 
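Note: to make the index-to-class mapping concrete, here is a sketch of how _get_protobuf_message_class resolves indices, assuming a parsed schema shaped like the Book/Author/Library.Section fixture in the tests below (the variable names are illustrative):

    # parsed_schema is the (descriptor_pool, file_descriptor_set) tuple from build_protobuf_schema.
    Book = _get_protobuf_message_class(parsed_schema, [0])        # first top-level message type
    Author = _get_protobuf_message_class(parsed_schema, [1])      # second top-level message type
    Section = _get_protobuf_message_class(parsed_schema, [2, 0])  # first nested type of the third message

    section = Section()
    section.ParseFromString(bytes.fromhex('0a0746696374696f6e'))  # field 1 (name) = "Fiction"
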
@@ -740,6 +831,17 @@ def build_avro_schema(schema_str):
 
 
 def build_protobuf_schema(schema_str):
+    """Build a Protobuf schema from a base64-encoded FileDescriptorSet.
+
+    Returns a tuple of (descriptor_pool, file_descriptor_set) that can be used
+    to dynamically select and instantiate message types based on message indices.
+
+    Args:
+        schema_str: Base64-encoded FileDescriptorSet
+
+    Returns:
+        tuple: (DescriptorPool, FileDescriptorSet)
+    """
     # schema is encoded in base64, decode it before passing it to ParseFromString
     schema_str = base64.b64decode(schema_str)
     descriptor_set = descriptor_pb2.FileDescriptorSet()
@@ -750,24 +852,4 @@ def build_protobuf_schema(schema_str):
     for fd_proto in descriptor_set.file:
         pool.Add(fd_proto)
 
-    # Pick the first message type from the first file descriptor
-    first_fd = descriptor_set.file[0]
-    # The file descriptor contains a list of message types (DescriptorProto)
-    first_message_proto = first_fd.message_type[0]
-
-    # The fully qualified name includes the package name + message name
-    package = first_fd.package
-    message_name = first_message_proto.name
-    if package:
-        full_name = f"{package}.{message_name}"
-    else:
-        full_name = message_name
-    # # Get the message descriptor
-    message_descriptor = pool.FindMessageTypeByName(full_name)
-    # Create a dynamic message class
-    schema = message_factory.GetMessageClass(message_descriptor)()
-
-    if schema is None:
-        raise ValueError("Protobuf schema cannot be None")
-
-    return schema
+    return (pool, descriptor_set)
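
Note: with build_protobuf_schema now returning the raw (pool, descriptor_set) pair, here is a rough sketch of preparing and consuming a schema string. The schema.desc file name is an assumption; a descriptor set can be produced beforehand with protoc --include_imports --descriptor_set_out=schema.desc schema.proto, and the check receives the base64 string through its configuration.

    import base64

    # Read a protoc-generated FileDescriptorSet and base64-encode it, as build_protobuf_schema expects.
    with open('schema.desc', 'rb') as f:
        schema_b64 = base64.b64encode(f.read()).decode('utf-8')

    pool, descriptor_set = build_protobuf_schema(schema_b64)
    message_class = _get_protobuf_message_class((pool, descriptor_set), [0])  # first message type in the file
    print(message_class().DESCRIPTOR.full_name)
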

kafka_consumer/tests/test_unit.py

Lines changed: 48 additions & 8 deletions
@@ -17,6 +17,7 @@
 from datadog_checks.kafka_consumer.kafka_consumer import (
     DATA_STREAMS_MESSAGES_CACHE_KEY,
     _get_interpolated_timestamp,
+    _get_protobuf_message_class,
     build_avro_schema,
     build_protobuf_schema,
     build_schema,
@@ -902,8 +903,12 @@ def test_schema_registry_explicit_configuration():
     assert result == (None, None, None, None), "Protobuf should fail when uses_schema_registry=True but no SR format"
 
     # Valid Protobuf message WITH schema registry format
+    # Confluent Protobuf wire format:
+    # [magic_byte][schema_id:4bytes][array_length:varint][index:varint][protobuf_payload]
     protobuf_message_with_sr = (
-        b'\x00\x00\x00\x01\x5e'
+        b'\x00\x00\x00\x01\x5e'  # magic byte (0x00) + schema ID 350 (0x0000015e)
+        b'\x01'  # message indices array length = 1
+        b'\x00'  # message index = 0
         b'\x08\xe8\xba\xb2\xeb\xd1\x9c\x02\x12\x1b\x54\x68\x65\x20\x47\x6f\x20\x50\x72\x6f\x67\x72\x61\x6d\x6d\x69\x6e\x67\x20\x4c\x61\x6e\x67\x75\x61\x67\x65'
         b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e'
     )
@@ -934,6 +939,40 @@ def test_schema_registry_explicit_configuration():
     assert result[3] is None, "Key schema ID should be None when key fails"
 
 
+def test_protobuf_message_indices_with_schema_registry():
+    """Test Confluent Protobuf wire format with different message indices."""
+    key = b'{"test": "key"}'
+
+    # Schema with multiple message types and nested type
+    # message Book { int64 isbn = 1; string title = 2; }
+    # message Author { string name = 1; int32 age = 2; }
+    # message Library { message Section { string name = 1; } string name = 1; }
+    protobuf_schema = (
+        'CpMBCgxzY2hlbWEucHJvdG8SC2NvbS5leGFtcGxlIh8KBEJvb2sSCgoEaXNibhgBKAMSCwoFdGl0bGUY'
+        'AigJIh8KBkF1dGhvchIKCgRuYW1lGAEoCRIJCgNhZ2UYAigFIiwKB0xpYnJhcnkSCgoEbmFtZRgBKAka'
+        'FQoHU2VjdGlvbhIKCgRuYW1lGAEoCWIGcHJvdG8z'
+    )
+    parsed_schema = build_schema('protobuf', protobuf_schema)
+
+    # Test index [0] - Book message
+    book_payload = bytes.fromhex('08e80712095465737420426f6f6b')
+    book_msg = b'\x00\x00\x00\x01\x5e\x01\x00' + book_payload
+    result = deserialize_message(MockedMessage(book_msg, key), 'protobuf', parsed_schema, True, 'json', '', False)
+    assert result[0] and 'Test Book' in result[0]
+
+    # Test index [1] - Author message
+    author_payload = bytes.fromhex('0a0a4a616e6520536d697468101e')
+    author_msg = b'\x00\x00\x00\x01\x5e\x01\x01' + author_payload
+    result = deserialize_message(MockedMessage(author_msg, key), 'protobuf', parsed_schema, True, 'json', '', False)
+    assert result[0] and 'Jane Smith' in result[0] and '30' in result[0]
+
+    # Test nested [2, 0] - Library.Section message
+    section_payload = bytes.fromhex('0a0746696374696f6e')
+    section_msg = b'\x00\x00\x00\x01\x5e\x02\x02\x00' + section_payload
+    result = deserialize_message(MockedMessage(section_msg, key), 'protobuf', parsed_schema, True, 'json', '', False)
+    assert result[0] and 'Fiction' in result[0]
+
+
 def mocked_time():
     return 400
 
@@ -1200,11 +1239,11 @@ def test_build_schema():
         'EhQKBXRpdGxlGAIgASgJUgV0aXRsZRIWCgZhdXRob3IYAyABKAlSBmF1dGhvcmIGcHJvdG8z'
     )
     protobuf_result = build_schema('protobuf', valid_protobuf_schema)
-    assert protobuf_result is not None
-    # The result should be a protobuf message class instance
-    assert hasattr(protobuf_result, 'isbn')
-    assert hasattr(protobuf_result, 'title')
-    assert hasattr(protobuf_result, 'author')
+    message_class = _get_protobuf_message_class(protobuf_result, [0])
+    message_instance = message_class()
+    assert hasattr(message_instance, 'isbn')
+    assert hasattr(message_instance, 'title')
+    assert hasattr(message_instance, 'author')
 
     # Test unknown format
     assert build_schema('unknown_format', 'some_schema') is None
@@ -1232,14 +1271,15 @@ def test_build_schema_error_cases():
     with pytest.raises(DecodeError):  # Will be a protobuf DecodeError
        build_schema('protobuf', 'SGVsbG8gV29ybGQ=')  # "Hello World" in base64
 
-    # Valid base64 but empty schema (should cause IndexError)
+    # Valid base64 but empty schema - should fail when trying to access message types
     # Create a minimal but empty FileDescriptorSet
     empty_descriptor = descriptor_pb2.FileDescriptorSet()
     empty_descriptor_bytes = empty_descriptor.SerializeToString()
     empty_descriptor_b64 = base64.b64encode(empty_descriptor_bytes).decode('utf-8')
 
+    result = build_schema('protobuf', empty_descriptor_b64)
     with pytest.raises(IndexError):  # Should fail when trying to access file[0]
-        build_schema('protobuf', empty_descriptor_b64)
+        _get_protobuf_message_class(result, [0])
 
 
 def test_build_schema_none_handling():
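
Note: as a companion to the fixtures above, a minimal sketch of assembling a Confluent-framed Protobuf value by hand; the schema ID (350) and the Author payload mirror the test values and are illustrative only.

    import struct

    magic_byte = b'\x00'
    schema_id = struct.pack('>I', 350)  # 4-byte big-endian schema ID -> b'\x00\x00\x01\x5e'
    message_indices = b'\x01\x01'       # indices array length 1, index 1 (second message type)
    author_payload = bytes.fromhex('0a0a4a616e6520536d697468101e')  # name = "Jane Smith", age = 30

    framed = magic_byte + schema_id + message_indices + author_payload
    assert framed == b'\x00\x00\x00\x01\x5e\x01\x01' + author_payload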
