Skip to content

Commit

Permalink
Fixed Abstract{Class,}Var misbehaving around multiple inheritance.
Browse files Browse the repository at this point in the history
The new implementation is also a fair bit simpler.
  • Loading branch information
patrick-kidger committed Oct 5, 2023
1 parent cea6f3d commit 5d367fa
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 160 deletions.
127 changes: 38 additions & 89 deletions equinox/_better_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ class ConcreteX(AbstractX):
should be called as `ConcreteX(attr2, attr1)`.
"""

# We can't just combine `ClassVar[AbstractVar[...]]`. At static checking time we
# fake `AbstractVar` as `ClassVar` to prevent it from appearing in __init__
# signatures. This means that static type checkers think they see
# `ClassVar[ClassVar[...]]` which is not allowed.
class AbstractClassVar(Generic[_T]):
"""Used to mark an abstract class attribute, along with its type. Used as:
```python
Expand Down Expand Up @@ -146,7 +142,9 @@ def _process_annotation(annotation):
"Stringified abstract annotations are not supported"
)
else:
return False, False
is_abstract = False
is_class = annotation.startswith("ClassVar[")
return is_abstract, is_class
else:
if annotation in (AbstractVar, AbstractClassVar):
raise TypeError(
Expand All @@ -162,108 +160,59 @@ def _process_annotation(annotation):
raise TypeError("`AbstractClassVar` can only have a single argument.")
is_abstract = True
is_class = True
elif get_origin(annotation) is ClassVar:
is_abstract = False
is_class = True
else:
is_abstract = False
is_class = False
return is_abstract, is_class


_sentinel = object()


# try:
# import beartype
# except ImportError:
# def is_subhint(subhint, superhint) -> bool:
# return True # no checking in this case
# else:
# from beartype.door import is_subhint


# TODO: reinstate once https://github.com/beartype/beartype/issues/271 is resolved.
def is_subhint(subhint, superhint) -> bool:
return True


def _is_concretisation(sub, super):
if isinstance(sub, str) or isinstance(super, str):
raise NotImplementedError("Stringified abstract annotations are not supported")
elif get_origin(super) is AbstractVar:
if get_origin(sub) in (AbstractVar, AbstractClassVar, ClassVar):
(sub_args,) = get_args(sub)
(sup_args,) = get_args(super)
else:
sub_args = sub
(sup_args,) = get_args(super)
elif get_origin(super) is AbstractClassVar:
if get_origin(sub) in (AbstractClassVar, ClassVar):
(sub_args,) = get_args(sub)
(sup_args,) = get_args(super)
else:
return False
else:
assert False
return is_subhint(sub_args, sup_args)


class ABCMeta(abc.ABCMeta):
def register(cls, subclass):
raise ValueError

def __new__(mcs, name, bases, namespace, /, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
abstract_vars = dict()
abstract_class_vars = dict()
cls_annotations = cls.__dict__.get("__annotations__", {})
for attr, group in [
("__abstractvars__", abstract_vars),
("__abstractclassvars__", abstract_class_vars),
]:
for base in bases:
for name, annotation in base.__dict__.get(attr, dict()).items():
try:
existing_annotation = group[name]
except KeyError:
pass
else:
if not (
_is_concretisation(annotation, existing_annotation)
or _is_concretisation(existing_annotation, annotation)
):

# We don't try and check that our AbstractVars and AbstractClassVars are
# consistently annotated across `cls` and each element of `bases`. Python just
# doesn't really provide any way of checking that two hints are compatible.
# (Subscripted generics make this complicated!)

abstract_vars = set()
abstract_class_vars = set()
for kls in reversed(cls.__mro__):
ann = kls.__dict__.get("__annotations__", {})
for name, annotation in ann.items():
is_abstract, is_class = _process_annotation(annotation)
if is_abstract:
if is_class:
if name in kls.__dict__:
raise TypeError(
"Base classes have mismatched type annotations for "
f"{name}"
f"Abstract class attribute {name} cannot have value"
)
try:
new_annotation = cls_annotations[name]
except KeyError:
pass
abstract_vars.discard(name)
abstract_class_vars.add(name)
else:
if not _is_concretisation(new_annotation, annotation):
if name in kls.__dict__:
raise TypeError(
"Base class and derived class have mismatched type "
f"annotations for {name}"
f"Abstract attribute {name} cannot have value"
)
# Not just `if name not in namespace`, as `cls.__dict__` may be
# slightly bigger from `__init_subclass__`.
if name not in cls.__dict__ and name not in cls_annotations:
group[name] = annotation
for name, annotation in cls_annotations.items():
is_abstract, is_class = _process_annotation(annotation)
if is_abstract:
if name in namespace:
if is_class:
raise TypeError(
f"Abstract class attribute {name} cannot have value"
)
else:
raise TypeError(f"Abstract attribute {name} cannot have value")
if is_class:
abstract_class_vars[name] = annotation
# If it's already an abstract class var, then superfluous to
# also consider it an abstract var.
if name not in abstract_class_vars:
abstract_vars.add(name)
else:
abstract_vars[name] = annotation
cls.__abstractvars__ = abstract_vars # pyright: ignore
cls.__abstractclassvars__ = abstract_class_vars # pyright: ignore
abstract_vars.discard(name) # not conditional on `is_class`
if is_class:
abstract_class_vars.discard(name)
for name in kls.__dict__.keys():
abstract_vars.discard(name)
abstract_class_vars.discard(name)
cls.__abstractvars__ = frozenset(abstract_vars) # pyright: ignore
cls.__abstractclassvars__ = frozenset(abstract_class_vars) # pyright: ignore
return cls

def __call__(cls, *args, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,8 @@ def __getattribute__(cls, item):

@dataclass_transform(field_specifiers=(dataclasses.field, field, static_field))
class _ModuleMeta(abc.ABCMeta):
pass
__abstractvars__: frozenset[str]
__abstractclassvars__: frozenset[str]


def _is_special_form(cls):
Expand Down
Loading

0 comments on commit 5d367fa

Please sign in to comment.