diff --git a/equinox/_better_abstract.py b/equinox/_better_abstract.py index 447c0ef6..f0c5b8e3 100644 --- a/equinox/_better_abstract.py +++ b/equinox/_better_abstract.py @@ -75,10 +75,6 @@ class ConcreteX(AbstractX): should be called as `ConcreteX(attr2, attr1)`. """ - # We can't just combine `ClassVar[AbstractVar[...]]`. At static checking time we - # fake `AbstractVar` as `ClassVar` to prevent it from appearing in __init__ - # signatures. This means that static type checkers think they see - # `ClassVar[ClassVar[...]]` which is not allowed. class AbstractClassVar(Generic[_T]): """Used to mark an abstract class attribute, along with its type. Used as: ```python @@ -146,7 +142,9 @@ def _process_annotation(annotation): "Stringified abstract annotations are not supported" ) else: - return False, False + is_abstract = False + is_class = annotation.startswith("ClassVar[") + return is_abstract, is_class else: if annotation in (AbstractVar, AbstractClassVar): raise TypeError( @@ -162,108 +160,59 @@ def _process_annotation(annotation): raise TypeError("`AbstractClassVar` can only have a single argument.") is_abstract = True is_class = True + elif get_origin(annotation) is ClassVar: + is_abstract = False + is_class = True else: is_abstract = False is_class = False return is_abstract, is_class -_sentinel = object() - - -# try: -# import beartype -# except ImportError: -# def is_subhint(subhint, superhint) -> bool: -# return True # no checking in this case -# else: -# from beartype.door import is_subhint - - -# TODO: reinstate once https://github.com/beartype/beartype/issues/271 is resolved. -def is_subhint(subhint, superhint) -> bool: - return True - - -def _is_concretisation(sub, super): - if isinstance(sub, str) or isinstance(super, str): - raise NotImplementedError("Stringified abstract annotations are not supported") - elif get_origin(super) is AbstractVar: - if get_origin(sub) in (AbstractVar, AbstractClassVar, ClassVar): - (sub_args,) = get_args(sub) - (sup_args,) = get_args(super) - else: - sub_args = sub - (sup_args,) = get_args(super) - elif get_origin(super) is AbstractClassVar: - if get_origin(sub) in (AbstractClassVar, ClassVar): - (sub_args,) = get_args(sub) - (sup_args,) = get_args(super) - else: - return False - else: - assert False - return is_subhint(sub_args, sup_args) - - class ABCMeta(abc.ABCMeta): def register(cls, subclass): raise ValueError def __new__(mcs, name, bases, namespace, /, **kwargs): cls = super().__new__(mcs, name, bases, namespace, **kwargs) - abstract_vars = dict() - abstract_class_vars = dict() - cls_annotations = cls.__dict__.get("__annotations__", {}) - for attr, group in [ - ("__abstractvars__", abstract_vars), - ("__abstractclassvars__", abstract_class_vars), - ]: - for base in bases: - for name, annotation in base.__dict__.get(attr, dict()).items(): - try: - existing_annotation = group[name] - except KeyError: - pass - else: - if not ( - _is_concretisation(annotation, existing_annotation) - or _is_concretisation(existing_annotation, annotation) - ): + + # We don't try and check that our AbstractVars and AbstractClassVars are + # consistently annotated across `cls` and each element of `bases`. Python just + # doesn't really provide any way of checking that two hints are compatible. + # (Subscripted generics make this complicated!) + + abstract_vars = set() + abstract_class_vars = set() + for kls in reversed(cls.__mro__): + ann = kls.__dict__.get("__annotations__", {}) + for name, annotation in ann.items(): + is_abstract, is_class = _process_annotation(annotation) + if is_abstract: + if is_class: + if name in kls.__dict__: raise TypeError( - "Base classes have mismatched type annotations for " - f"{name}" + f"Abstract class attribute {name} cannot have value" ) - try: - new_annotation = cls_annotations[name] - except KeyError: - pass + abstract_vars.discard(name) + abstract_class_vars.add(name) else: - if not _is_concretisation(new_annotation, annotation): + if name in kls.__dict__: raise TypeError( - "Base class and derived class have mismatched type " - f"annotations for {name}" + f"Abstract attribute {name} cannot have value" ) - # Not just `if name not in namespace`, as `cls.__dict__` may be - # slightly bigger from `__init_subclass__`. - if name not in cls.__dict__ and name not in cls_annotations: - group[name] = annotation - for name, annotation in cls_annotations.items(): - is_abstract, is_class = _process_annotation(annotation) - if is_abstract: - if name in namespace: - if is_class: - raise TypeError( - f"Abstract class attribute {name} cannot have value" - ) - else: - raise TypeError(f"Abstract attribute {name} cannot have value") - if is_class: - abstract_class_vars[name] = annotation + # If it's already an abstract class var, then superfluous to + # also consider it an abstract var. + if name not in abstract_class_vars: + abstract_vars.add(name) else: - abstract_vars[name] = annotation - cls.__abstractvars__ = abstract_vars # pyright: ignore - cls.__abstractclassvars__ = abstract_class_vars # pyright: ignore + abstract_vars.discard(name) # not conditional on `is_class` + if is_class: + abstract_class_vars.discard(name) + for name in kls.__dict__.keys(): + abstract_vars.discard(name) + abstract_class_vars.discard(name) + cls.__abstractvars__ = frozenset(abstract_vars) # pyright: ignore + cls.__abstractclassvars__ = frozenset(abstract_class_vars) # pyright: ignore return cls def __call__(cls, *args, **kwargs): diff --git a/equinox/_module.py b/equinox/_module.py index ae133df3..9d63af4f 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -486,7 +486,8 @@ def __getattribute__(cls, item): @dataclass_transform(field_specifiers=(dataclasses.field, field, static_field)) class _ModuleMeta(abc.ABCMeta): - pass + __abstractvars__: frozenset[str] + __abstractclassvars__: frozenset[str] def _is_special_form(cls): diff --git a/tests/test_abstract.py b/tests/test_abstract.py index 91693f6f..114cbf42 100644 --- a/tests/test_abstract.py +++ b/tests/test_abstract.py @@ -28,147 +28,180 @@ class B(A): B() +def test_abstract_attribute_stringified(): + with pytest.raises(NotImplementedError): + + class A(eqx.Module): + x: "AbstractVar[bool]" + + def test_abstract_attribute(): class A(eqx.Module): x: AbstractVar[bool] + assert A.__abstractvars__ == frozenset({"x"}) + assert A.__abstractclassvars__ == frozenset() + with pytest.raises(TypeError, match="abstract attributes"): A() - # class B(eqx.Module): - # x: "AbstractVar[bool]" - - # with pytest.raises(TypeError, match="abstract attributes"): - # B() - - class C(A): + class B(A): y: int + assert B.__abstractvars__ == frozenset({"x"}) + assert B.__abstractclassvars__ == frozenset() + with pytest.raises(TypeError, match="abstract attributes"): - C(y=2) + B(y=2) - class D(A): + class C(A): x: bool y: str - D(x=True, y="hi") + assert C.__abstractvars__ == frozenset() + assert C.__abstractclassvars__ == frozenset() - class E(A): + C(x=True, y="hi") + + class D(A): y: str x: bool # different order - e = E("hi", True) - assert e.x is True - assert e.y == "hi" + assert D.__abstractvars__ == frozenset() + assert D.__abstractclassvars__ == frozenset() - class F(A): + d = D("hi", True) + assert d.x is True + assert d.y == "hi" + + class E(A): y: str @property def x(self): return True - f = F(y="hi") - assert f.x is True - assert f.y == "hi" + assert E.__abstractvars__ == frozenset() + assert E.__abstractclassvars__ == frozenset() + + e = E(y="hi") + assert e.x is True + assert e.y == "hi" with pytest.raises(TypeError, match="unsubscripted"): - class G(eqx.Module): + class F(eqx.Module): x: AbstractVar - # with pytest.raises(TypeError, match="mismatched type annotations"): - - # class H(A): - # x: str - - class I(A): # noqa: E742 + class G(A): x: AbstractVar[bool] - class J(A): + assert G.__abstractvars__ == frozenset({"x"}) + assert G.__abstractclassvars__ == frozenset() + + class H(A): x: AbstractClassVar[bool] - class K(A): + assert H.__abstractvars__ == frozenset() + assert H.__abstractclassvars__ == frozenset({"x"}) + + class I(A): # noqa: E742 x: bool + assert I.__abstractvars__ == frozenset() + assert I.__abstractclassvars__ == frozenset() + I(True) + with pytest.raises(TypeError, match="cannot have value"): - class L(eqx.Module): + class J(eqx.Module): x: AbstractVar[bool] = True - class M1(A): + class K(A): x = True y: bool = False - class M2(A): + assert K.__abstractvars__ == frozenset() + assert K.__abstractclassvars__ == frozenset() + K() + + class L(A): x = True + assert L.__abstractvars__ == frozenset() + assert L.__abstractclassvars__ == frozenset() + L() + def test_abstract_class_attribute(): class A(eqx.Module): x: AbstractClassVar[bool] + assert A.__abstractvars__ == frozenset() + assert A.__abstractclassvars__ == frozenset({"x"}) + with pytest.raises(TypeError, match="abstract class attributes"): A() - # class B(eqx.Module): - # x: "AbstractClassVar[bool]" - - # with pytest.raises(TypeError, match="abstract class attributes"): - # B() - - class C(A): + class B(A): y: int + assert B.__abstractvars__ == frozenset() + assert B.__abstractclassvars__ == frozenset({"x"}) + with pytest.raises(TypeError, match="abstract class attributes"): - C(y=2) + B(y=2) + + with pytest.raises(TypeError, match="unsubscripted"): - with pytest.raises(TypeError, match="mismatched type annotations"): + class C(eqx.Module): + x: AbstractClassVar - class D(A): - x: bool - y: str + class D(A): + x: AbstractClassVar[bool] - with pytest.raises(TypeError, match="mismatched type annotations"): + assert D.__abstractvars__ == frozenset() + assert D.__abstractclassvars__ == frozenset({"x"}) - class E(A): - y: str - x: bool # different order + class E(A): + x: ClassVar[bool] - with pytest.raises(TypeError, match="unsubscripted"): + assert E.__abstractvars__ == frozenset() + assert E.__abstractclassvars__ == frozenset() - class G(eqx.Module): - x: AbstractClassVar + with pytest.raises(TypeError, match="cannot have value"): - # with pytest.raises(TypeError, match="mismatched type annotations"): + class F(eqx.Module): + x: AbstractClassVar[bool] = True - # class H1(A): - # x: str + class G(A): + x = True + y: bool = False - # with pytest.raises(TypeError, match="mismatched type annotations"): + assert G.__abstractvars__ == frozenset() + assert G.__abstractclassvars__ == frozenset() - # class H2(A): - # x: ClassVar[str] + class H(A): + x = True - # with pytest.raises(TypeError, match="mismatched type annotations"): + assert H.__abstractvars__ == frozenset() + assert H.__abstractclassvars__ == frozenset() - # class I(A): # noqa: E742 - # x: AbstractVar[bool] + class I(A): # noqa: E742 + x: AbstractVar[bool] - class J(A): - x: AbstractClassVar[bool] + assert I.__abstractvars__ == frozenset() + assert I.__abstractclassvars__ == frozenset({"x"}) - class K(A): - x: ClassVar[bool] - with pytest.raises(TypeError, match="cannot have value"): +def test_abstract_multiple_inheritance(): + class A(eqx.Module): + x: AbstractVar[int] - class L(eqx.Module): - x: AbstractClassVar[bool] = True + class B(eqx.Module): + x: int - class M1(A): - x = True - y: bool = False + class C(B, A): + pass - class M2(A): - x = True + C(1)