From c3d7aca7ec7e0292d9906f327cf8cf0190c1c46d Mon Sep 17 00:00:00 2001 From: Maria Grimaldi Date: Thu, 30 Jan 2025 10:18:40 +0100 Subject: [PATCH] feat: [FC-0074] add support for annotated python dicts as avro map type (#433) Enable Python dicts to be mapped to Avro Map type for schema generation, expanding support for event payloads. Unlike the previous approach of mapping dicts to records, this method prevents conflicts with data attributes and avoids errors when dictionary contents (not type) are unknown. --- CHANGELOG.rst | 7 + openedx_events/__init__.py | 2 +- openedx_events/event_bus/avro/deserializer.py | 17 ++- openedx_events/event_bus/avro/schema.py | 21 ++- .../event_bus/avro/tests/test_avro.py | 24 +++- .../event_bus/avro/tests/test_deserializer.py | 136 ++++++++++++++---- .../event_bus/avro/tests/test_schema.py | 28 +++- .../event_bus/avro/tests/test_utilities.py | 7 + openedx_events/event_bus/avro/types.py | 2 +- 9 files changed, 208 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 578c9db2..d13b6d34 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,13 @@ Change Log Unreleased __________ +[9.16.0] - 2025-01-30 +--------------------- + +Added +~~~~~ + +* Added support for annotated Python dictionaries as Avro Map type. [9.15.2] - 2025-01-16 --------------------- diff --git a/openedx_events/__init__.py b/openedx_events/__init__.py index ee77bd4a..5643ee00 100644 --- a/openedx_events/__init__.py +++ b/openedx_events/__init__.py @@ -5,4 +5,4 @@ more information about the project. """ -__version__ = "9.15.2" +__version__ = "9.16.0" diff --git a/openedx_events/event_bus/avro/deserializer.py b/openedx_events/event_bus/avro/deserializer.py index 9bf07cff..0c4f0a63 100644 --- a/openedx_events/event_bus/avro/deserializer.py +++ b/openedx_events/event_bus/avro/deserializer.py @@ -42,16 +42,27 @@ def _deserialized_avro_record_dict_to_object(data: dict, data_type, deserializer elif data_type in PYTHON_TYPE_TO_AVRO_MAPPING: return data elif data_type_origin == list: - # returns types of list contents - # if data_type == List[int], arg_data_type = (int,) + # Returns types of list contents. + # Example: if data_type == List[int], arg_data_type = (int,) arg_data_type = get_args(data_type) if not arg_data_type: raise TypeError( "List without annotation type is not supported. The argument should be a type, for eg., List[int]" ) - # check whether list items type is in basic types. + # Check whether list items type is in basic types. if arg_data_type[0] in SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING: return data + elif data_type_origin == dict: + # Returns types of dict contents. + # Example: if data_type == Dict[str, int], arg_data_type = (str, int) + arg_data_type = get_args(data_type) + if not arg_data_type: + raise TypeError( + "Dict without annotation type is not supported. The argument should be a type, for eg., Dict[str, int]" + ) + # Check whether dict items type is in basic types. + if all(arg in SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING for arg in arg_data_type): + return data elif hasattr(data_type, "__attrs_attrs__"): transformed = {} for attribute in data_type.__attrs_attrs__: diff --git a/openedx_events/event_bus/avro/schema.py b/openedx_events/event_bus/avro/schema.py index 5d23d0de..fa508f45 100644 --- a/openedx_events/event_bus/avro/schema.py +++ b/openedx_events/event_bus/avro/schema.py @@ -69,14 +69,14 @@ def _create_avro_field_definition(data_key, data_type, previously_seen_types, field["type"] = field_type # Case 2: data_type is a simple type that can be converted directly to an Avro type elif data_type in PYTHON_TYPE_TO_AVRO_MAPPING: - if PYTHON_TYPE_TO_AVRO_MAPPING[data_type] in ["record", "array"]: + if PYTHON_TYPE_TO_AVRO_MAPPING[data_type] in ["map", "array"]: # pylint: disable-next=broad-exception-raised raise Exception("Unable to generate Avro schema for dict or array fields without annotation types.") avro_type = PYTHON_TYPE_TO_AVRO_MAPPING[data_type] field["type"] = avro_type elif data_type_origin == list: - # returns types of list contents - # if data_type == List[int], arg_data_type = (int,) + # Returns types of list contents. + # Example: if data_type == List[int], arg_data_type = (int,) arg_data_type = get_args(data_type) if not arg_data_type: raise TypeError( @@ -89,6 +89,21 @@ def _create_avro_field_definition(data_key, data_type, previously_seen_types, f" {set(SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.keys())}" ) field["type"] = {"type": PYTHON_TYPE_TO_AVRO_MAPPING[data_type_origin], "items": avro_type} + elif data_type_origin == dict: + # Returns types of dict contents. + # Example: if data_type == Dict[str, int], arg_data_type = (str, int) + arg_data_type = get_args(data_type) + if not arg_data_type: + raise TypeError( + "Dict without annotation type is not supported. The argument should be a type, for eg., Dict[str, int]" + ) + avro_type = SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.get(arg_data_type[1]) + if avro_type is None: + raise TypeError( + "Only following types are supported for dict arguments:" + f" {set(SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.keys())}" + ) + field["type"] = {"type": PYTHON_TYPE_TO_AVRO_MAPPING[data_type_origin], "values": avro_type} # Case 3: data_type is an attrs class elif hasattr(data_type, "__attrs_attrs__"): # Inner Attrs Class diff --git a/openedx_events/event_bus/avro/tests/test_avro.py b/openedx_events/event_bus/avro/tests/test_avro.py index 87634067..67c806bd 100644 --- a/openedx_events/event_bus/avro/tests/test_avro.py +++ b/openedx_events/event_bus/avro/tests/test_avro.py @@ -2,7 +2,7 @@ import io import os from datetime import datetime -from typing import List +from typing import List, Union from unittest import TestCase from uuid import UUID, uuid4 @@ -43,6 +43,7 @@ def generate_test_data_for_schema(schema): # pragma: no cover 'string': "default", 'double': 1.0, 'null': None, + 'map': {'key': 'value'}, } def get_default_value_or_raise(schema_field_type): @@ -71,6 +72,9 @@ def get_default_value_or_raise(schema_field_type): elif sub_field_type == "record": # if we're dealing with a record, recurse into the record data_dict.update({key: generate_test_data_for_schema(field_type)}) + elif sub_field_type == "map": + # if we're dealing with a map, "values" will be the type of values in the map + data_dict.update({key: {"key": get_default_value_or_raise(field_type["values"])}}) else: raise Exception(f"Unsupported type {field_type}") # pylint: disable=broad-exception-raised @@ -112,6 +116,24 @@ def generate_test_event_data_for_data_type(data_type): # pragma: no cover datetime: datetime.now(), CCXLocator: CCXLocator(org='edx', course='DemoX', run='Demo_course', ccx='1'), UUID: uuid4(), + dict[str, str]: {'key': 'value'}, + dict[str, int]: {'key': 1}, + dict[str, float]: {'key': 1.0}, + dict[str, bool]: {'key': True}, + dict[str, CourseKey]: {'key': CourseKey.from_string("course-v1:edX+DemoX.1+2014")}, + dict[str, UsageKey]: {'key': UsageKey.from_string( + "block-v1:edx+DemoX+Demo_course+type@video+block@UaEBjyMjcLW65gaTXggB93WmvoxGAJa0JeHRrDThk", + )}, + dict[str, LibraryLocatorV2]: {'key': LibraryLocatorV2.from_string('lib:MITx:reallyhardproblems')}, + dict[str, LibraryUsageLocatorV2]: { + 'key': LibraryUsageLocatorV2.from_string('lb:MITx:reallyhardproblems:problem:problem1'), + }, + dict[str, List[int]]: {'key': [1, 2, 3]}, + dict[str, List[str]]: {'key': ["hi", "there"]}, + dict[str, dict[str, str]]: {'key': {'key': 'value'}}, + dict[str, dict[str, int]]: {'key': {'key': 1}}, + dict[str, Union[str, int]]: {'key': 'value'}, + dict[str, Union[str, int, float]]: {'key': 1.0}, } data_dict = {} for attribute in data_type.__attrs_attrs__: diff --git a/openedx_events/event_bus/avro/tests/test_deserializer.py b/openedx_events/event_bus/avro/tests/test_deserializer.py index e7037998..f8522b03 100644 --- a/openedx_events/event_bus/avro/tests/test_deserializer.py +++ b/openedx_events/event_bus/avro/tests/test_deserializer.py @@ -1,14 +1,16 @@ """Tests for avro.deserializer""" import json from datetime import datetime -from typing import List +from typing import Dict, List from unittest import TestCase +import ddt from opaque_keys.edx.keys import CourseKey, UsageKey from opaque_keys.edx.locator import LibraryLocatorV2, LibraryUsageLocatorV2 from openedx_events.event_bus.avro.deserializer import AvroSignalDeserializer, deserialize_bytes_to_event_data from openedx_events.event_bus.avro.tests.test_utilities import ( + ComplexAttrs, EventData, NestedAttrsWithDefaults, NestedNonAttrs, @@ -23,6 +25,7 @@ from openedx_events.tests.utils import FreezeSignalCacheMixin +@ddt.ddt class TestAvroSignalDeserializerCache(TestCase, FreezeSignalCacheMixin): """Test AvroSignalDeserializer""" @@ -30,36 +33,66 @@ def setUp(self) -> None: super().setUp() self.maxDiff = None - def test_schema_string(self): + @ddt.data( + ( + SimpleAttrs, + { + "name": "CloudEvent", + "type": "record", + "doc": "Avro Event Format for CloudEvents created with openedx_events/schema", + "namespace": "simple.signal", + "fields": [ + { + "name": "data", + "type": { + "name": "SimpleAttrs", + "type": "record", + "fields": [ + {"name": "boolean_field", "type": "boolean"}, + {"name": "int_field", "type": "long"}, + {"name": "float_field", "type": "double"}, + {"name": "bytes_field", "type": "bytes"}, + {"name": "string_field", "type": "string"}, + ], + }, + }, + ], + } + ), + ( + ComplexAttrs, + { + "name": "CloudEvent", + "type": "record", + "doc": "Avro Event Format for CloudEvents created with openedx_events/schema", + "namespace": "simple.signal", + "fields": [ + { + "name": "data", + "type": { + "name": "ComplexAttrs", + "type": "record", + "fields": [ + {"name": "list_field", "type": {"type": "array", "items": "long"}}, + {"name": "dict_field", "type": {"type": "map", "values": "long"}}, + ], + }, + }, + ], + } + ) + ) + @ddt.unpack + def test_schema_string(self, data_cls, expected_schema): """ Test JSON round-trip; schema creation is tested more fully in test_schema.py. """ SIGNAL = create_simple_signal({ - "data": SimpleAttrs + "data": data_cls }) + actual_schema = json.loads(AvroSignalDeserializer(SIGNAL).schema_string()) - expected_schema = { - 'name': 'CloudEvent', - 'type': 'record', - 'doc': 'Avro Event Format for CloudEvents created with openedx_events/schema', - 'namespace': 'simple.signal', - 'fields': [ - { - 'name': 'data', - 'type': { - 'name': 'SimpleAttrs', - 'type': 'record', - 'fields': [ - {'name': 'boolean_field', 'type': 'boolean'}, - {'name': 'int_field', 'type': 'long'}, - {'name': 'float_field', 'type': 'double'}, - {'name': 'bytes_field', 'type': 'bytes'}, - {'name': 'string_field', 'type': 'string'}, - ] - } - } - ] - } + assert actual_schema == expected_schema def test_convert_dict_to_event_data(self): @@ -233,6 +266,59 @@ def test_deserialization_of_list_without_annotation(self): with self.assertRaises(TypeError): deserializer.from_dict(initial_dict) + def test_deserialization_of_dict_with_annotation(self): + """ + Check that deserialization works as expected when dict data is annotated. + """ + DICT_SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + initial_dict = {"dict_input": {"key1": 1, "key2": 3}} + + deserializer = AvroSignalDeserializer(DICT_SIGNAL) + event_data = deserializer.from_dict(initial_dict) + expected_event_data = {"key1": 1, "key2": 3} + test_data = event_data["dict_input"] + + self.assertIsInstance(test_data, dict) + self.assertEqual(test_data, expected_event_data) + + def test_deserialization_of_dict_without_annotation(self): + """ + Check that deserialization raises error when dict data is not annotated. + + Create dummy signal to bypass schema check while initializing deserializer. Then, + update signal with incomplete type info to test whether correct exceptions are raised while deserializing data. + """ + SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + DICT_SIGNAL = create_simple_signal({"dict_input": Dict}) + initial_dict = {"dict_input": {"key1": 1, "key2": 3}} + + deserializer = AvroSignalDeserializer(SIGNAL) + deserializer.signal = DICT_SIGNAL + + with self.assertRaises(TypeError): + deserializer.from_dict(initial_dict) + + def test_deserialization_of_dict_with_complex_types_fails(self): + SIGNAL = create_simple_signal({"dict_input": Dict[str, list]}) + with self.assertRaises(TypeError): + AvroSignalDeserializer(SIGNAL) + initial_dict = {"dict_input": {"key1": [1, 3], "key2": [4, 5]}} + # create dummy signal to bypass schema check while initializing deserializer + # This allows us to test whether correct exceptions are raised while deserializing data + DUMMY_SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + deserializer = AvroSignalDeserializer(DUMMY_SIGNAL) + # Update signal with incorrect type info + deserializer.signal = SIGNAL + with self.assertRaises(TypeError): + deserializer.from_dict(initial_dict) + + def test_deserialization_of_dicts_with_keys_of_complex_types_fails(self): + SIGNAL = create_simple_signal({"dict_input": Dict[CourseKey, int]}) + deserializer = AvroSignalDeserializer(SIGNAL) + initial_dict = {"dict_input": {CourseKey.from_string("course-v1:edX+DemoX.1+2014"): 1}} + with self.assertRaises(TypeError): + deserializer.from_dict(initial_dict) + def test_deserialization_of_nested_list_fails(self): """ Check that deserialization raises error when nested list data is passed. diff --git a/openedx_events/event_bus/avro/tests/test_schema.py b/openedx_events/event_bus/avro/tests/test_schema.py index 8ad6245b..a3410643 100644 --- a/openedx_events/event_bus/avro/tests/test_schema.py +++ b/openedx_events/event_bus/avro/tests/test_schema.py @@ -1,7 +1,7 @@ """ Tests for event_bus.avro.schema module """ -from typing import List +from typing import Dict, List from unittest import TestCase from openedx_events.event_bus.avro.schema import schema_from_signal @@ -245,8 +245,9 @@ class UnextendedClass: def test_throw_exception_to_list_or_dict_types_without_annotation(self): LIST_SIGNAL = create_simple_signal({"list_input": list}) - DICT_SIGNAL = create_simple_signal({"list_input": dict}) + DICT_SIGNAL = create_simple_signal({"dict_input": dict}) LIST_WITHOUT_ANNOTATION_SIGNAL = create_simple_signal({"list_input": List}) + DICT_WITHOUT_ANNOTATION_SIGNAL = create_simple_signal({"dict_input": Dict}) with self.assertRaises(Exception): schema_from_signal(LIST_SIGNAL) @@ -256,6 +257,14 @@ def test_throw_exception_to_list_or_dict_types_without_annotation(self): with self.assertRaises(TypeError): schema_from_signal(LIST_WITHOUT_ANNOTATION_SIGNAL) + with self.assertRaises(TypeError): + schema_from_signal(DICT_WITHOUT_ANNOTATION_SIGNAL) + + def test_throw_exception_invalid_dict_annotation(self): + INVALID_DICT_SIGNAL = create_simple_signal({"dict_input": Dict[str, NestedAttrsWithDefaults]}) + with self.assertRaises(TypeError): + schema_from_signal(INVALID_DICT_SIGNAL) + def test_list_with_annotation_works(self): LIST_SIGNAL = create_simple_signal({"list_input": List[int]}) expected_dict = { @@ -270,3 +279,18 @@ def test_list_with_annotation_works(self): } schema = schema_from_signal(LIST_SIGNAL) self.assertDictEqual(schema, expected_dict) + + def test_dict_with_annotation_works(self): + DICT_SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + expected_dict = { + 'name': 'CloudEvent', + 'type': 'record', + 'doc': 'Avro Event Format for CloudEvents created with openedx_events/schema', + 'namespace': 'simple.signal', + 'fields': [{ + 'name': 'dict_input', + 'type': {'type': 'map', 'values': 'long'}, + }], + } + schema = schema_from_signal(DICT_SIGNAL) + self.assertDictEqual(schema, expected_dict) diff --git a/openedx_events/event_bus/avro/tests/test_utilities.py b/openedx_events/event_bus/avro/tests/test_utilities.py index 1644c3e4..05b1b9c7 100644 --- a/openedx_events/event_bus/avro/tests/test_utilities.py +++ b/openedx_events/event_bus/avro/tests/test_utilities.py @@ -39,6 +39,13 @@ class SimpleAttrs: string_field: str +@attr.s(auto_attribs=True) +class ComplexAttrs: + """Class with all complex type fields""" + list_field: list[int] + dict_field: dict[str, int] + + @attr.s(auto_attribs=True) class SubTestData0: """Subclass for testing nested attrs""" diff --git a/openedx_events/event_bus/avro/types.py b/openedx_events/event_bus/avro/types.py index f3bc2536..b757a899 100644 --- a/openedx_events/event_bus/avro/types.py +++ b/openedx_events/event_bus/avro/types.py @@ -9,6 +9,6 @@ PYTHON_TYPE_TO_AVRO_MAPPING = { **SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING, None: "null", - dict: "record", + dict: "map", list: "array", }