Skip to content

feat: make SignalGroupDescriptor generic #297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
28 changes: 16 additions & 12 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import sys
import warnings
import weakref
from collections.abc import Iterable
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Literal,
Optional,
TypeVar,
Expand Down Expand Up @@ -39,6 +41,7 @@

T = TypeVar("T", bound=type)
S = TypeVar("S")
GroupType = TypeVar("GroupType", bound=SignalGroup)


_EQ_OPERATORS: dict[type, dict[str, EqOperator]] = {}
Expand Down Expand Up @@ -153,10 +156,10 @@ def _psygnal_relocate_info_(self, emission_info: EmissionInfo) -> EmissionInfo:

def _build_dataclass_signal_group(
cls: type,
signal_group_class: type[SignalGroup],
signal_group_class: type[GroupType],
equality_operators: Iterable[tuple[str, EqOperator]] | None = None,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
) -> type[SignalGroup]:
) -> type[GroupType]:
"""Build a SignalGroup with events for each field in a dataclass.

Parameters
Expand Down Expand Up @@ -424,7 +427,7 @@ def _setattr_and_emit_(self: object, name: str, value: Any) -> None:
return _inner(super_setattr) if super_setattr else _inner


class SignalGroupDescriptor:
class SignalGroupDescriptor(Generic[GroupType]):
"""Create a [`psygnal.SignalGroup`][] on first instance attribute access.

This descriptor is designed to be used as a class attribute on a dataclass-like
Expand Down Expand Up @@ -544,12 +547,12 @@ def __init__(
warn_on_no_fields: bool = True,
cache_on_instance: bool = True,
patch_setattr: bool = True,
signal_group_class: type[SignalGroup] | None = None,
signal_group_class: type[GroupType] | None = None,
collect_fields: bool = True,
connect_child_events: bool = True,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
):
grp_cls = signal_group_class or SignalGroup
grp_cls = signal_group_class or cast("type[GroupType]", SignalGroup)
if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)):
raise TypeError( # pragma: no cover
f"'signal_group_class' must be a subclass of SignalGroup, not {grp_cls}"
Expand All @@ -574,11 +577,11 @@ def __init__(
self._patch_setattr = patch_setattr
self._connect_child_events = connect_child_events

self._signal_group_class: type[SignalGroup] = grp_cls
self._signal_group_class: type[GroupType] = grp_cls
self._collect_fields = collect_fields
self._signal_aliases = signal_aliases

self._signal_groups: dict[int, type[SignalGroup]] = {}
self._signal_groups: dict[int, type[GroupType]] = {}

def __set_name__(self, owner: type, name: str) -> None:
"""Called when this descriptor is added to class `owner` as attribute `name`."""
Expand Down Expand Up @@ -618,11 +621,11 @@ def _do_patch_setattr(self, owner: type, with_aliases: bool = True) -> None:
def __get__(self, instance: None, owner: type) -> SignalGroupDescriptor: ...

@overload
def __get__(self, instance: object, owner: type) -> SignalGroup: ...
def __get__(self, instance: object, owner: type) -> GroupType: ...

def __get__(
self, instance: object, owner: type
) -> SignalGroup | SignalGroupDescriptor:
) -> GroupType | SignalGroupDescriptor:
"""Return a SignalGroup instance for `instance`."""
if instance is None:
return self
Expand Down Expand Up @@ -652,15 +655,16 @@ def __get__(
lambda: connect_child_events(instance, recurse=True, _group=grp)
)

return self._instance_map[obj_id]
return cast("GroupType", self._instance_map[obj_id])

def _get_signal_group(self, owner: type) -> type[SignalGroup]:
def _get_signal_group(self, owner: type) -> type[GroupType]:
type_id = id(owner)
if type_id not in self._signal_groups:
self._signal_groups[type_id] = self._create_group(owner)
return self._signal_groups[type_id]

def _create_group(self, owner: type) -> type[SignalGroup]:
def _create_group(self, owner: type) -> type[GroupType]:
# Do not collect fields from owner class, copy the SignalGroup
if not self._collect_fields:
# Do not collect fields from owner class
Group = copy.deepcopy(self._signal_group_class)
Expand Down
Loading