diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index f37d4eea..4bdffdf9 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -376,6 +376,10 @@ def get_canonical_name(self): def resolve_datatype(name): + + if not isinstance(name, str): + raise TypeError(f"Input 'name' must be of type 'str', but got type '{type(name).__name__}'") + _special_types = { "BINARY": IntType(1, False), "BIPOLAR": BipolarType(), diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index 1bd0fece..e98a885f 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -29,6 +29,7 @@ import numpy as np from qonnx.core.datatype import DataType +from qonnx.core.datatype import resolve_datatype def test_datatypes(): @@ -97,3 +98,53 @@ def test_smallest_possible(): assert DataType.get_smallest_possible(-1) == DataType["BIPOLAR"] assert DataType.get_smallest_possible(-3) == DataType["INT3"] assert DataType.get_smallest_possible(-3.2) == DataType["FLOAT32"] + + +def test_resolve_datatype(): + assert resolve_datatype("BIPOLAR") + assert resolve_datatype("BINARY") + assert resolve_datatype("TERNARY") + assert resolve_datatype("UINT2") + assert resolve_datatype("UINT3") + assert resolve_datatype("UINT4") + assert resolve_datatype("UINT8") + assert resolve_datatype("UINT16") + assert resolve_datatype("UINT32") + assert resolve_datatype("INT2") + assert resolve_datatype("INT3") + assert resolve_datatype("INT4") + assert resolve_datatype("INT8") + assert resolve_datatype("INT16") + assert resolve_datatype("INT32") + assert resolve_datatype("FLOAT32") + + +def test_input_type_error(): + + def test_resolve_datatype(input): + # test with invalid input to check if the TypeError works + try: + resolve_datatype(input) # This should raise a TypeError + except TypeError as e: + pass + else: + print("Test with invalid input failed: No TypeError was raised.") + + test_resolve_datatype(123) + test_resolve_datatype(1.23) + test_resolve_datatype(DataType["BIPOLAR"]) + test_resolve_datatype(DataType["BINARY"]) + test_resolve_datatype(DataType["TERNARY"]) + test_resolve_datatype(DataType["UINT2"]) + test_resolve_datatype(DataType["UINT3"]) + test_resolve_datatype(DataType["UINT4"]) + test_resolve_datatype(DataType["UINT8"]) + test_resolve_datatype(DataType["UINT16"]) + test_resolve_datatype(DataType["UINT32"]) + test_resolve_datatype(DataType["INT2"]) + test_resolve_datatype(DataType["INT3"]) + test_resolve_datatype(DataType["INT4"]) + test_resolve_datatype(DataType["INT8"]) + test_resolve_datatype(DataType["INT16"]) + test_resolve_datatype(DataType["INT32"]) + test_resolve_datatype(DataType["FLOAT32"])