diff --git a/lib/dl_pydantic/dl_pydantic/typed.py b/lib/dl_pydantic/dl_pydantic/typed.py index 06e41fa00..00f8a659d 100644 --- a/lib/dl_pydantic/dl_pydantic/typed.py +++ b/lib/dl_pydantic/dl_pydantic/typed.py @@ -25,6 +25,7 @@ class TypedMeta(pydantic_model_construction.ModelMetaclass): def __init__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]): cls._classes: dict[str, type["TypedBaseModel"]] = {} + cls._unknown_class: type["TypedBaseModel"] | None = None class TypedBaseModel(base.BaseModel, metaclass=TypedMeta): @@ -37,13 +38,24 @@ class TypedBaseModel(base.BaseModel, metaclass=TypedMeta): @classmethod def register(cls, name: str, class_: Type) -> None: # noqa: UP006 if name in cls._classes: - raise ValueError(f"Class with name '{name}' already registered") + raise ValueError(f"{cls.__name__}(type={name}) already registered") if not issubclass(class_, cls): raise ValueError(f"Class '{class_}' must be subclass of '{cls}'") cls._classes[name] = class_ - LOGGER.info(f"Registered class '{name}' as '{class_}'") + LOGGER.debug("Registered %s(type=%s): %s", cls.__name__, name, class_) + + @classmethod + def register_unknown(cls, class_: Type) -> None: # noqa: UP006 + if cls._unknown_class is not None: + raise ValueError("Unknown class already registered") + + if not issubclass(class_, cls): + raise ValueError(f"Class '{class_}' must be subclass of {cls}") + + cls._unknown_class = class_ + LOGGER.debug("Registered unknown for %s: %s", cls.__name__, class_) @classmethod def _prepare_data(cls, data: dict[str, Any]) -> dict[str, Any]: @@ -66,9 +78,12 @@ def factory(cls, data: Any) -> Self: raise ValueError("Data must be dict") class_name = cls._get_class_name(data) - if class_name not in cls._classes: + if class_name in cls._classes: + class_ = cls._classes[class_name] + elif cls._unknown_class is not None: + class_ = cls._unknown_class + else: raise ValueError(f"Unknown type: {class_name}") - class_ = cls._classes[class_name] data = class_._prepare_data(data) diff --git a/lib/dl_pydantic/dl_pydantic_tests/unit/test_typed.py b/lib/dl_pydantic/dl_pydantic_tests/unit/test_typed.py index 07b107085..8f783f9c4 100644 --- a/lib/dl_pydantic/dl_pydantic_tests/unit/test_typed.py +++ b/lib/dl_pydantic/dl_pydantic_tests/unit/test_typed.py @@ -378,3 +378,39 @@ class Root(dl_pydantic.BaseModel): root = Root.model_validate({"children": {"child": {}}}) assert isinstance(root.children["child"], Child) + + +def test_register_unknown() -> None: + class Base(dl_pydantic.TypedBaseModel): + ... + + class Child(Base): + ... + + Base.register("child", Child) + + with pytest.raises(ValueError): + Base.factory({"type": "bebebe"}) + + class UnknownChild(Base): + type: str = "unknown" + raw_data: typing.Any + + @pydantic.model_validator(mode="before") + @classmethod + def transform_to_raw_data(cls, data: typing.Any) -> typing.Any: + return {"raw_data": data} + + Base.register_unknown(UnknownChild) + + raw_data = {"type": "bebebe"} + data = Base.factory(raw_data) + + assert isinstance(data, UnknownChild) + assert data.type == "unknown" + assert data.raw_data == raw_data + + raw_data = {"type": "child"} + data = Base.factory(raw_data) + + assert isinstance(data, Child)