diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py index 4869e5ba5..b035982ee 100644 --- a/openfisca_core/indexed_enums/enum.py +++ b/openfisca_core/indexed_enums/enum.py @@ -67,14 +67,14 @@ def encode( return array # String array - if isinstance(array, numpy.ndarray) and array.dtype.kind in {"U", "S"}: + if array.dtype.kind in {"U", "S"}: array = numpy.select( [array == item.name for item in cls], [item.index for item in cls], ).astype(t.ArrayEnum) # Enum items arrays - elif isinstance(array, numpy.ndarray) and array.dtype.kind == "O": + elif array.dtype.kind == "O": # Ensure we are comparing the comparable. The problem this fixes: # On entering this method "cls" will generally come from # variable.possible_values, while the array values may come from @@ -87,12 +87,18 @@ def encode( # name to check that the values in the array, if non-empty, are of # the right type. if len(array) > 0 and cls.__name__ is array[0].__class__.__name__: - cls = array[0].__class__ + klass = array[0].__class__ + + else: + klass = cls array = numpy.select( - [array == item for item in cls], - [item.index for item in cls], + [array == item for item in klass], + [item.index for item in klass], ).astype(t.ArrayEnum) array = numpy.asarray(array, dtype=t.ArrayEnum) return EnumArray(array, cls) + + +__all__ = ["Enum"] diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index de6f83df4..fef8d5386 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -151,16 +151,26 @@ def _is_an_enum(self, other: object) -> TypeGuard[t.Enum]: if self.possible_values is None: raise NotImplementedError + if other is None: + raise NotImplementedError + return ( not hasattr(other, "__name__") and other.__class__.__name__ is self.possible_values.__name__ ) def _is_an_enum_type(self, other: object) -> TypeGuard[type[t.Enum]]: + name: None | str + if self.possible_values is None: raise NotImplementedError - return ( - hasattr(other, "__name__") - and other.__name__ is self.possible_values.__name__ - ) + if other is None: + raise NotImplementedError + + name = getattr(other, "__name__", None) + + return isinstance(name, str) and name is self.possible_values.__name__ + + +__all__ = ["EnumArray"] diff --git a/openfisca_core/indexed_enums/types.py b/openfisca_core/indexed_enums/types.py index 57657b8c5..08d0c79e4 100644 --- a/openfisca_core/indexed_enums/types.py +++ b/openfisca_core/indexed_enums/types.py @@ -1,12 +1,14 @@ from __future__ import annotations -from openfisca_core.types import Array as Array -from openfisca_core.types import ArrayAny as ArrayAny # noqa: F401 -from openfisca_core.types import ArrayBool as ArrayBool # noqa: F401 -from openfisca_core.types import ArrayBytes as ArrayBytes # noqa: F401 -from openfisca_core.types import ArrayEnum as ArrayEnum -from openfisca_core.types import ArrayInt as ArrayInt # noqa: F401 -from openfisca_core.types import ArrayStr as ArrayStr # noqa: F401 +from openfisca_core.types import ( + Array, + ArrayAny, + ArrayBool, + ArrayBytes, + ArrayEnum, + ArrayInt, + ArrayStr, +) import abc import enum @@ -20,3 +22,16 @@ class Enum(enum.Enum): class EnumArray(Array[ArrayEnum], metaclass=abc.ABCMeta): ... + + +__all__ = [ + "Array", + "ArrayAny", + "ArrayBool", + "ArrayBytes", + "ArrayEnum", + "ArrayInt", + "ArrayStr", + "Enum", + "EnumArray", +]