diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 578c9db2..d13b6d34 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,13 @@ Change Log Unreleased __________ +[9.16.0] - 2025-01-30 +--------------------- + +Added +~~~~~ + +* Added support for annotated Python dictionaries as Avro Map type. [9.15.2] - 2025-01-16 --------------------- diff --git a/openedx_events/__init__.py b/openedx_events/__init__.py index ee77bd4a..5643ee00 100644 --- a/openedx_events/__init__.py +++ b/openedx_events/__init__.py @@ -5,4 +5,4 @@ more information about the project. """ -__version__ = "9.15.2" +__version__ = "9.16.0" diff --git a/openedx_events/event_bus/avro/deserializer.py b/openedx_events/event_bus/avro/deserializer.py index 9bf07cff..0c4f0a63 100644 --- a/openedx_events/event_bus/avro/deserializer.py +++ b/openedx_events/event_bus/avro/deserializer.py @@ -42,16 +42,27 @@ def _deserialized_avro_record_dict_to_object(data: dict, data_type, deserializer elif data_type in PYTHON_TYPE_TO_AVRO_MAPPING: return data elif data_type_origin == list: - # returns types of list contents - # if data_type == List[int], arg_data_type = (int,) + # Returns types of list contents. + # Example: if data_type == List[int], arg_data_type = (int,) arg_data_type = get_args(data_type) if not arg_data_type: raise TypeError( "List without annotation type is not supported. The argument should be a type, for eg., List[int]" ) - # check whether list items type is in basic types. + # Check whether list items type is in basic types. if arg_data_type[0] in SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING: return data + elif data_type_origin == dict: + # Returns types of dict contents. + # Example: if data_type == Dict[str, int], arg_data_type = (str, int) + arg_data_type = get_args(data_type) + if not arg_data_type: + raise TypeError( + "Dict without annotation type is not supported. The argument should be a type, for eg., Dict[str, int]" + ) + # Check whether dict items type is in basic types. + if all(arg in SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING for arg in arg_data_type): + return data elif hasattr(data_type, "__attrs_attrs__"): transformed = {} for attribute in data_type.__attrs_attrs__: diff --git a/openedx_events/event_bus/avro/schema.py b/openedx_events/event_bus/avro/schema.py index 5d23d0de..fa508f45 100644 --- a/openedx_events/event_bus/avro/schema.py +++ b/openedx_events/event_bus/avro/schema.py @@ -69,14 +69,14 @@ def _create_avro_field_definition(data_key, data_type, previously_seen_types, field["type"] = field_type # Case 2: data_type is a simple type that can be converted directly to an Avro type elif data_type in PYTHON_TYPE_TO_AVRO_MAPPING: - if PYTHON_TYPE_TO_AVRO_MAPPING[data_type] in ["record", "array"]: + if PYTHON_TYPE_TO_AVRO_MAPPING[data_type] in ["map", "array"]: # pylint: disable-next=broad-exception-raised raise Exception("Unable to generate Avro schema for dict or array fields without annotation types.") avro_type = PYTHON_TYPE_TO_AVRO_MAPPING[data_type] field["type"] = avro_type elif data_type_origin == list: - # returns types of list contents - # if data_type == List[int], arg_data_type = (int,) + # Returns types of list contents. + # Example: if data_type == List[int], arg_data_type = (int,) arg_data_type = get_args(data_type) if not arg_data_type: raise TypeError( @@ -89,6 +89,21 @@ def _create_avro_field_definition(data_key, data_type, previously_seen_types, f" {set(SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.keys())}" ) field["type"] = {"type": PYTHON_TYPE_TO_AVRO_MAPPING[data_type_origin], "items": avro_type} + elif data_type_origin == dict: + # Returns types of dict contents. + # Example: if data_type == Dict[str, int], arg_data_type = (str, int) + arg_data_type = get_args(data_type) + if not arg_data_type: + raise TypeError( + "Dict without annotation type is not supported. The argument should be a type, for eg., Dict[str, int]" + ) + avro_type = SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.get(arg_data_type[1]) + if avro_type is None: + raise TypeError( + "Only following types are supported for dict arguments:" + f" {set(SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING.keys())}" + ) + field["type"] = {"type": PYTHON_TYPE_TO_AVRO_MAPPING[data_type_origin], "values": avro_type} # Case 3: data_type is an attrs class elif hasattr(data_type, "__attrs_attrs__"): # Inner Attrs Class diff --git a/openedx_events/event_bus/avro/tests/test_avro.py b/openedx_events/event_bus/avro/tests/test_avro.py index 87634067..67c806bd 100644 --- a/openedx_events/event_bus/avro/tests/test_avro.py +++ b/openedx_events/event_bus/avro/tests/test_avro.py @@ -2,7 +2,7 @@ import io import os from datetime import datetime -from typing import List +from typing import List, Union from unittest import TestCase from uuid import UUID, uuid4 @@ -43,6 +43,7 @@ def generate_test_data_for_schema(schema): # pragma: no cover 'string': "default", 'double': 1.0, 'null': None, + 'map': {'key': 'value'}, } def get_default_value_or_raise(schema_field_type): @@ -71,6 +72,9 @@ def get_default_value_or_raise(schema_field_type): elif sub_field_type == "record": # if we're dealing with a record, recurse into the record data_dict.update({key: generate_test_data_for_schema(field_type)}) + elif sub_field_type == "map": + # if we're dealing with a map, "values" will be the type of values in the map + data_dict.update({key: {"key": get_default_value_or_raise(field_type["values"])}}) else: raise Exception(f"Unsupported type {field_type}") # pylint: disable=broad-exception-raised @@ -112,6 +116,24 @@ def generate_test_event_data_for_data_type(data_type): # pragma: no cover datetime: datetime.now(), CCXLocator: CCXLocator(org='edx', course='DemoX', run='Demo_course', ccx='1'), UUID: uuid4(), + dict[str, str]: {'key': 'value'}, + dict[str, int]: {'key': 1}, + dict[str, float]: {'key': 1.0}, + dict[str, bool]: {'key': True}, + dict[str, CourseKey]: {'key': CourseKey.from_string("course-v1:edX+DemoX.1+2014")}, + dict[str, UsageKey]: {'key': UsageKey.from_string( + "block-v1:edx+DemoX+Demo_course+type@video+block@UaEBjyMjcLW65gaTXggB93WmvoxGAJa0JeHRrDThk", + )}, + dict[str, LibraryLocatorV2]: {'key': LibraryLocatorV2.from_string('lib:MITx:reallyhardproblems')}, + dict[str, LibraryUsageLocatorV2]: { + 'key': LibraryUsageLocatorV2.from_string('lb:MITx:reallyhardproblems:problem:problem1'), + }, + dict[str, List[int]]: {'key': [1, 2, 3]}, + dict[str, List[str]]: {'key': ["hi", "there"]}, + dict[str, dict[str, str]]: {'key': {'key': 'value'}}, + dict[str, dict[str, int]]: {'key': {'key': 1}}, + dict[str, Union[str, int]]: {'key': 'value'}, + dict[str, Union[str, int, float]]: {'key': 1.0}, } data_dict = {} for attribute in data_type.__attrs_attrs__: diff --git a/openedx_events/event_bus/avro/tests/test_deserializer.py b/openedx_events/event_bus/avro/tests/test_deserializer.py index e7037998..f8522b03 100644 --- a/openedx_events/event_bus/avro/tests/test_deserializer.py +++ b/openedx_events/event_bus/avro/tests/test_deserializer.py @@ -1,14 +1,16 @@ """Tests for avro.deserializer""" import json from datetime import datetime -from typing import List +from typing import Dict, List from unittest import TestCase +import ddt from opaque_keys.edx.keys import CourseKey, UsageKey from opaque_keys.edx.locator import LibraryLocatorV2, LibraryUsageLocatorV2 from openedx_events.event_bus.avro.deserializer import AvroSignalDeserializer, deserialize_bytes_to_event_data from openedx_events.event_bus.avro.tests.test_utilities import ( + ComplexAttrs, EventData, NestedAttrsWithDefaults, NestedNonAttrs, @@ -23,6 +25,7 @@ from openedx_events.tests.utils import FreezeSignalCacheMixin +@ddt.ddt class TestAvroSignalDeserializerCache(TestCase, FreezeSignalCacheMixin): """Test AvroSignalDeserializer""" @@ -30,36 +33,66 @@ def setUp(self) -> None: super().setUp() self.maxDiff = None - def test_schema_string(self): + @ddt.data( + ( + SimpleAttrs, + { + "name": "CloudEvent", + "type": "record", + "doc": "Avro Event Format for CloudEvents created with openedx_events/schema", + "namespace": "simple.signal", + "fields": [ + { + "name": "data", + "type": { + "name": "SimpleAttrs", + "type": "record", + "fields": [ + {"name": "boolean_field", "type": "boolean"}, + {"name": "int_field", "type": "long"}, + {"name": "float_field", "type": "double"}, + {"name": "bytes_field", "type": "bytes"}, + {"name": "string_field", "type": "string"}, + ], + }, + }, + ], + } + ), + ( + ComplexAttrs, + { + "name": "CloudEvent", + "type": "record", + "doc": "Avro Event Format for CloudEvents created with openedx_events/schema", + "namespace": "simple.signal", + "fields": [ + { + "name": "data", + "type": { + "name": "ComplexAttrs", + "type": "record", + "fields": [ + {"name": "list_field", "type": {"type": "array", "items": "long"}}, + {"name": "dict_field", "type": {"type": "map", "values": "long"}}, + ], + }, + }, + ], + } + ) + ) + @ddt.unpack + def test_schema_string(self, data_cls, expected_schema): """ Test JSON round-trip; schema creation is tested more fully in test_schema.py. """ SIGNAL = create_simple_signal({ - "data": SimpleAttrs + "data": data_cls }) + actual_schema = json.loads(AvroSignalDeserializer(SIGNAL).schema_string()) - expected_schema = { - 'name': 'CloudEvent', - 'type': 'record', - 'doc': 'Avro Event Format for CloudEvents created with openedx_events/schema', - 'namespace': 'simple.signal', - 'fields': [ - { - 'name': 'data', - 'type': { - 'name': 'SimpleAttrs', - 'type': 'record', - 'fields': [ - {'name': 'boolean_field', 'type': 'boolean'}, - {'name': 'int_field', 'type': 'long'}, - {'name': 'float_field', 'type': 'double'}, - {'name': 'bytes_field', 'type': 'bytes'}, - {'name': 'string_field', 'type': 'string'}, - ] - } - } - ] - } + assert actual_schema == expected_schema def test_convert_dict_to_event_data(self): @@ -233,6 +266,59 @@ def test_deserialization_of_list_without_annotation(self): with self.assertRaises(TypeError): deserializer.from_dict(initial_dict) + def test_deserialization_of_dict_with_annotation(self): + """ + Check that deserialization works as expected when dict data is annotated. + """ + DICT_SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + initial_dict = {"dict_input": {"key1": 1, "key2": 3}} + + deserializer = AvroSignalDeserializer(DICT_SIGNAL) + event_data = deserializer.from_dict(initial_dict) + expected_event_data = {"key1": 1, "key2": 3} + test_data = event_data["dict_input"] + + self.assertIsInstance(test_data, dict) + self.assertEqual(test_data, expected_event_data) + + def test_deserialization_of_dict_without_annotation(self): + """ + Check that deserialization raises error when dict data is not annotated. + + Create dummy signal to bypass schema check while initializing deserializer. Then, + update signal with incomplete type info to test whether correct exceptions are raised while deserializing data. + """ + SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + DICT_SIGNAL = create_simple_signal({"dict_input": Dict}) + initial_dict = {"dict_input": {"key1": 1, "key2": 3}} + + deserializer = AvroSignalDeserializer(SIGNAL) + deserializer.signal = DICT_SIGNAL + + with self.assertRaises(TypeError): + deserializer.from_dict(initial_dict) + + def test_deserialization_of_dict_with_complex_types_fails(self): + SIGNAL = create_simple_signal({"dict_input": Dict[str, list]}) + with self.assertRaises(TypeError): + AvroSignalDeserializer(SIGNAL) + initial_dict = {"dict_input": {"key1": [1, 3], "key2": [4, 5]}} + # create dummy signal to bypass schema check while initializing deserializer + # This allows us to test whether correct exceptions are raised while deserializing data + DUMMY_SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + deserializer = AvroSignalDeserializer(DUMMY_SIGNAL) + # Update signal with incorrect type info + deserializer.signal = SIGNAL + with self.assertRaises(TypeError): + deserializer.from_dict(initial_dict) + + def test_deserialization_of_dicts_with_keys_of_complex_types_fails(self): + SIGNAL = create_simple_signal({"dict_input": Dict[CourseKey, int]}) + deserializer = AvroSignalDeserializer(SIGNAL) + initial_dict = {"dict_input": {CourseKey.from_string("course-v1:edX+DemoX.1+2014"): 1}} + with self.assertRaises(TypeError): + deserializer.from_dict(initial_dict) + def test_deserialization_of_nested_list_fails(self): """ Check that deserialization raises error when nested list data is passed. diff --git a/openedx_events/event_bus/avro/tests/test_schema.py b/openedx_events/event_bus/avro/tests/test_schema.py index 8ad6245b..a3410643 100644 --- a/openedx_events/event_bus/avro/tests/test_schema.py +++ b/openedx_events/event_bus/avro/tests/test_schema.py @@ -1,7 +1,7 @@ """ Tests for event_bus.avro.schema module """ -from typing import List +from typing import Dict, List from unittest import TestCase from openedx_events.event_bus.avro.schema import schema_from_signal @@ -245,8 +245,9 @@ class UnextendedClass: def test_throw_exception_to_list_or_dict_types_without_annotation(self): LIST_SIGNAL = create_simple_signal({"list_input": list}) - DICT_SIGNAL = create_simple_signal({"list_input": dict}) + DICT_SIGNAL = create_simple_signal({"dict_input": dict}) LIST_WITHOUT_ANNOTATION_SIGNAL = create_simple_signal({"list_input": List}) + DICT_WITHOUT_ANNOTATION_SIGNAL = create_simple_signal({"dict_input": Dict}) with self.assertRaises(Exception): schema_from_signal(LIST_SIGNAL) @@ -256,6 +257,14 @@ def test_throw_exception_to_list_or_dict_types_without_annotation(self): with self.assertRaises(TypeError): schema_from_signal(LIST_WITHOUT_ANNOTATION_SIGNAL) + with self.assertRaises(TypeError): + schema_from_signal(DICT_WITHOUT_ANNOTATION_SIGNAL) + + def test_throw_exception_invalid_dict_annotation(self): + INVALID_DICT_SIGNAL = create_simple_signal({"dict_input": Dict[str, NestedAttrsWithDefaults]}) + with self.assertRaises(TypeError): + schema_from_signal(INVALID_DICT_SIGNAL) + def test_list_with_annotation_works(self): LIST_SIGNAL = create_simple_signal({"list_input": List[int]}) expected_dict = { @@ -270,3 +279,18 @@ def test_list_with_annotation_works(self): } schema = schema_from_signal(LIST_SIGNAL) self.assertDictEqual(schema, expected_dict) + + def test_dict_with_annotation_works(self): + DICT_SIGNAL = create_simple_signal({"dict_input": Dict[str, int]}) + expected_dict = { + 'name': 'CloudEvent', + 'type': 'record', + 'doc': 'Avro Event Format for CloudEvents created with openedx_events/schema', + 'namespace': 'simple.signal', + 'fields': [{ + 'name': 'dict_input', + 'type': {'type': 'map', 'values': 'long'}, + }], + } + schema = schema_from_signal(DICT_SIGNAL) + self.assertDictEqual(schema, expected_dict) diff --git a/openedx_events/event_bus/avro/tests/test_utilities.py b/openedx_events/event_bus/avro/tests/test_utilities.py index 1644c3e4..05b1b9c7 100644 --- a/openedx_events/event_bus/avro/tests/test_utilities.py +++ b/openedx_events/event_bus/avro/tests/test_utilities.py @@ -39,6 +39,13 @@ class SimpleAttrs: string_field: str +@attr.s(auto_attribs=True) +class ComplexAttrs: + """Class with all complex type fields""" + list_field: list[int] + dict_field: dict[str, int] + + @attr.s(auto_attribs=True) class SubTestData0: """Subclass for testing nested attrs""" diff --git a/openedx_events/event_bus/avro/types.py b/openedx_events/event_bus/avro/types.py index f3bc2536..b757a899 100644 --- a/openedx_events/event_bus/avro/types.py +++ b/openedx_events/event_bus/avro/types.py @@ -9,6 +9,6 @@ PYTHON_TYPE_TO_AVRO_MAPPING = { **SIMPLE_PYTHON_TYPE_TO_AVRO_MAPPING, None: "null", - dict: "record", + dict: "map", list: "array", }