diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index d2a7b62d09090e..991fccd9f49140 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -9,7 +9,16 @@ from functools import lru_cache import logging import time -from typing import TYPE_CHECKING, Any, Literal, TypedDict +from typing import ( + TYPE_CHECKING, + Any, + Literal, + TypedDict, + TypeVar, + Generic, + Type, + Union, +) import attr from yarl import URL @@ -49,6 +58,9 @@ else: from propcache.api import under_cached_property +_EntryTypeT = TypeVar("_EntryTypeT", "DeviceEntry", "DeletedDeviceEntry") +_EnumT = TypeVar("_EnumT", bound=StrEnum) + _LOGGER = logging.getLogger(__name__) DATA_REGISTRY: HassKey[DeviceRegistry] = HassKey("device_registry") @@ -169,11 +181,11 @@ class _EventDeviceRegistryUpdatedData_Update(TypedDict): changes: dict[str, Any] -type EventDeviceRegistryUpdatedData = ( - _EventDeviceRegistryUpdatedData_Create - | _EventDeviceRegistryUpdatedData_Remove - | _EventDeviceRegistryUpdatedData_Update -) +EventDeviceRegistryUpdatedData = Union[ + _EventDeviceRegistryUpdatedData_Create, + _EventDeviceRegistryUpdatedData_Remove, + _EventDeviceRegistryUpdatedData_Update, +] class DeviceEntryType(StrEnum): @@ -222,7 +234,7 @@ def __init__( f"already registered with {existing_device}" ) - +# Validate a DeviceInfo mapping and classify it as one of: 'link', 'primary', or 'secondary', based on which allowed keys it contains. def _validate_device_info( config_entry: ConfigEntry, device_info: DeviceInfo, @@ -262,7 +274,7 @@ def _validate_device_info( _cached_parse_url = lru_cache(maxsize=512)(URL) """Parse a URL and cache the result.""" - +# Validate `configuration_url` is a supported URL with scheme and host. def _validate_configuration_url(value: Any) -> str | None: """Validate and convert configuration_url.""" if value is None: @@ -276,10 +288,9 @@ def _validate_configuration_url(value: Any) -> str | None: return url_as_str - +# Format the mac address string for entry into dev reg. @lru_cache(maxsize=512) def format_mac(mac: str) -> str: - """Format the mac address string for entry into dev reg.""" to_test = mac if len(to_test) == 17 and to_test.count(":") == 5: @@ -297,23 +308,21 @@ def format_mac(mac: str) -> str: # Not sure how formatted, return original return mac - +# Normalize connections to ensure we can match mac addresses. def _normalize_connections( connections: Iterable[tuple[str, str]], ) -> set[tuple[str, str]]: - """Normalize connections to ensure we can match mac addresses.""" return { (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value) for key, value in connections } - +# Check connections normalization used as attrs validator. def _normalize_connections_validator( instance: Any, attribute: Any, connections: Iterable[tuple[str, str]], ) -> None: - """Check connections normalization used as attrs validator.""" for key, value in connections: if key == CONNECTION_NETWORK_MAC and format_mac(value) != value: raise ValueError(f"Invalid mac address format: {value}") @@ -352,16 +361,16 @@ class DeviceEntry: _cache: dict[str, Any] = attr.ib(factory=dict, eq=False, init=False) @property + # True when the device has any disabler set (user, integration, config_entry) def disabled(self) -> bool: """Return if entry is disabled.""" return self.disabled_by is not None @property + # Convert the DeviceEntry into a JSON-serializable dict used in events and external presentation. def dict_repr(self) -> dict[str, Any]: """Return a dict representation of the entry.""" - # Convert sets and tuples to lists - # so the JSON serializer does not have to do - # it every time + # Convert sets and tuples to lists so the JSON serializer does not have to do it every time return { "area_id": self.area_id, "configuration_url": self.configuration_url, @@ -390,9 +399,9 @@ def dict_repr(self) -> dict[str, Any]: "via_device_id": self.via_device_id, } + # Return a cached JSON representation of the entry. @under_cached_property def json_repr(self) -> bytes | None: - """Return a cached JSON representation of the entry.""" try: dict_repr = self.dict_repr return json_bytes(dict_repr) @@ -406,9 +415,9 @@ def json_repr(self) -> bytes | None: ) return None + # Return a json fragment for storage. @under_cached_property - def as_storage_fragment(self) -> json_fragment: - """Return a json fragment for storage.""" + def as_storage_fragment(self) -> Any: return json_fragment( json_bytes( { @@ -505,9 +514,9 @@ def to_device_entry( name_by_user=self.name_by_user, ) + # Return a json fragment for storage. @under_cached_property - def as_storage_fragment(self) -> json_fragment: - """Return a json fragment for storage.""" + def as_storage_fragment(self) -> Any: return json_fragment( json_bytes( { @@ -535,10 +544,8 @@ def as_storage_fragment(self) -> json_fragment: ) ) - +# Store entity registry data. class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): - """Store entity registry data.""" - async def _async_migrate_func( # noqa: C901 self, old_major_version: int, @@ -635,10 +642,8 @@ async def _async_migrate_func( # noqa: C901 raise NotImplementedError return old_data - -class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)]( - BaseRegistryItems[_EntryTypeT] -): +# Container for device registry items. +class DeviceRegistryItems(BaseRegistryItems[_EntryTypeT], Generic[_EntryTypeT]): """Container for device registry items, maps device id -> entry. Maintains two additional indexes: @@ -646,23 +651,23 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)]( - (DOMAIN, identifier) -> entry """ + # Initialize the container. def __init__(self) -> None: - """Initialize the container.""" super().__init__() self._connections: dict[tuple[str, str], _EntryTypeT] = {} self._identifiers: dict[tuple[str, str], _EntryTypeT] = {} + # Index an entry. def _index_entry(self, key: str, entry: _EntryTypeT) -> None: - """Index an entry.""" for connection in entry.connections: self._connections[connection] = entry for identifier in entry.identifiers: self._identifiers[identifier] = entry + # Unindex an entry. def _unindex_entry( self, key: str, replacement_entry: _EntryTypeT | None = None ) -> None: - """Unindex an entry.""" old_entry = self.data[key] for connection in old_entry.connections: if connection in self._connections: @@ -671,12 +676,12 @@ def _unindex_entry( if identifier in self._identifiers: del self._identifiers[identifier] + # Get an entry by identifiers or connections. def get_entry( self, identifiers: set[tuple[str, str]] | None = None, connections: set[tuple[str, str]] | None = None, ) -> _EntryTypeT | None: - """Get entry from identifiers or connections.""" if identifiers: for identifier in identifiers: if identifier in self._identifiers: @@ -688,12 +693,12 @@ def get_entry( return self._connections[connection] return None + # Get all entries matching identifiers or connections. def get_entries( self, identifiers: set[tuple[str, str]] | None, connections: set[tuple[str, str]] | None, ) -> Iterable[_EntryTypeT]: - """Get entries from identifiers or connections.""" if identifiers: for identifier in identifiers: if identifier in self._identifiers: @@ -703,9 +708,8 @@ def get_entries( if connection in self._connections: yield self._connections[connection] - +# Container for active device registry items. class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]): - """Container for active (non-deleted) device registry entries.""" def __init__(self) -> None: """Initialize the container. @@ -721,8 +725,8 @@ def __init__(self) -> None: self._config_entry_id_index: RegistryIndexType = defaultdict(dict) self._labels_index: RegistryIndexType = defaultdict(dict) + # Index an entry. def _index_entry(self, key: str, entry: DeviceEntry) -> None: - """Index an entry.""" super()._index_entry(key, entry) if (area_id := entry.area_id) is not None: self._area_id_index[area_id][key] = True @@ -731,10 +735,11 @@ def _index_entry(self, key: str, entry: DeviceEntry) -> None: for config_entry_id in entry.config_entries: self._config_entry_id_index[config_entry_id][key] = True + # Unindex an entry. def _unindex_entry( self, key: str, replacement_entry: DeviceEntry | None = None ) -> None: - """Unindex an entry.""" + # Remove from area/label/config_entry indexes in addition to base indexes. entry = self.data[key] if area_id := entry.area_id: self._unindex_entry_value(key, area_id, self._area_id_index) @@ -745,35 +750,33 @@ def _unindex_entry( self._unindex_entry_value(key, config_entry_id, self._config_entry_id_index) super()._unindex_entry(key, replacement_entry) + # Get devices for area. def get_devices_for_area_id(self, area_id: str) -> list[DeviceEntry]: - """Get devices for area.""" data = self.data return [data[key] for key in self._area_id_index.get(area_id, ())] + # Get devices for label. def get_devices_for_label(self, label: str) -> list[DeviceEntry]: - """Get devices for label.""" data = self.data return [data[key] for key in self._labels_index.get(label, ())] + # Get devices for config entry. def get_devices_for_config_entry_id( self, config_entry_id: str ) -> list[DeviceEntry]: - """Get devices for config entry.""" data = self.data return [ data[key] for key in self._config_entry_id_index.get(config_entry_id, ()) ] - +# Class to hold a registry of devices. class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): - """Class to hold a registry of devices.""" - devices: ActiveDeviceRegistryItems deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] _device_data: dict[str, DeviceEntry] + # Initialize the device registry. def __init__(self, hass: HomeAssistant) -> None: - """Initialize the device registry.""" self.hass = hass self._store = DeviceRegistryStore( hass, @@ -783,15 +786,16 @@ def __init__(self, hass: HomeAssistant) -> None: minor_version=STORAGE_VERSION_MINOR, ) + # Get a device by device id. @callback def async_get(self, device_id: str) -> DeviceEntry | None: - """Get device. - + """ We retrieve the DeviceEntry from the underlying dict to avoid the overhead of the UserDict __getitem__. """ return self._device_data.get(device_id) + # Get a device by identifiers or connections. @callback def async_get_device( self, @@ -801,13 +805,13 @@ def async_get_device( """Check if device is registered.""" return self.devices.get_entry(identifiers, connections) + # Substitute placeholders in entity name. def _substitute_name_placeholders( self, domain: str, name: str, translation_placeholders: Mapping[str, str], ) -> str: - """Substitute placeholders in entity name.""" try: return name.format(**translation_placeholders) except KeyError as err: @@ -829,44 +833,37 @@ def _substitute_name_placeholders( return name @callback - def async_get_or_create( + def _prepare_device_info_and_normalize( self, - *, - config_entry_id: str, - config_subentry_id: str | None | UndefinedType = UNDEFINED, - configuration_url: str | URL | None | UndefinedType = UNDEFINED, - connections: set[tuple[str, str]] | None | UndefinedType = UNDEFINED, - created_at: str | datetime | UndefinedType = UNDEFINED, # will be ignored - default_manufacturer: str | None | UndefinedType = UNDEFINED, - default_model: str | None | UndefinedType = UNDEFINED, - default_name: str | None | UndefinedType = UNDEFINED, - # To disable a device if it gets created, does not affect existing devices - disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, - entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED, - hw_version: str | None | UndefinedType = UNDEFINED, - identifiers: set[tuple[str, str]] | None | UndefinedType = UNDEFINED, - manufacturer: str | None | UndefinedType = UNDEFINED, - model: str | None | UndefinedType = UNDEFINED, - model_id: str | None | UndefinedType = UNDEFINED, - modified_at: str | datetime | UndefinedType = UNDEFINED, # will be ignored - name: str | None | UndefinedType = UNDEFINED, - serial_number: str | None | UndefinedType = UNDEFINED, - suggested_area: str | None | UndefinedType = UNDEFINED, - sw_version: str | None | UndefinedType = UNDEFINED, - translation_key: str | None = None, - translation_placeholders: Mapping[str, str] | None = None, - via_device: tuple[str, str] | None | UndefinedType = UNDEFINED, - ) -> DeviceEntry: - """Get device. Create if it doesn't exist.""" + config_entry, + configuration_url, + connections, + identifiers, + translation_key, + translation_placeholders, + name, + default_manufacturer=UNDEFINED, + default_model=UNDEFINED, + default_name=UNDEFINED, + entry_type=UNDEFINED, + hw_version=UNDEFINED, + manufacturer=UNDEFINED, + model=UNDEFINED, + model_id=UNDEFINED, + serial_number=UNDEFINED, + suggested_area=UNDEFINED, + sw_version=UNDEFINED, + via_device=UNDEFINED, + ): + """Build and validate a DeviceInfo mapping and normalize inputs. + + Returns a tuple: (device_info, device_info_type, connections_set, + identifiers_set, name). Connections and identifiers are converted to + sets and MAC addresses in connections are normalized. + """ if configuration_url is not UNDEFINED: configuration_url = _validate_configuration_url(configuration_url) - config_entry = self.hass.config_entries.async_get_entry(config_entry_id) - if config_entry is None: - raise HomeAssistantError( - f"Can't link device to unknown config entry {config_entry_id}" - ) - if translation_key: full_translation_key = ( f"component.{config_entry.domain}.device.{translation_key}.name" @@ -879,10 +876,7 @@ def async_get_or_create( config_entry.domain, translated_name, translation_placeholders or {} ) - # Reconstruct a DeviceInfo dict from the arguments. - # When we upgrade to Python 3.12, we can change this method to instead - # accept kwargs typed as a DeviceInfo dict (PEP 692) - device_info: DeviceInfo = { # type: ignore[assignment] + device_info = { key: val for key, val in ( ("configuration_url", configuration_url), @@ -908,25 +902,40 @@ def async_get_or_create( device_info_type = _validate_device_info(config_entry, device_info) if identifiers is None or identifiers is UNDEFINED: - identifiers = set() + identifiers_set = set() + else: + identifiers_set = set(identifiers) if connections is None or connections is UNDEFINED: - connections = set() + connections_set = set() else: - connections = _normalize_connections(connections) + connections_set = _normalize_connections(connections) - device = self.devices.get_entry( - identifiers=identifiers, connections=connections - ) + return device_info, device_info_type, connections_set, identifiers_set, name + # Create or restore a device entry. + @callback + def _create_or_restore_device( + self, + config_entry, + config_subentry_id, + connections_set, + identifiers_set, + suggested_area, + disabled_by, + ): + """Return existing device or create/restore a new DeviceEntry. + + Returns (device, is_new) + """ + device = self.devices.get_entry(identifiers=identifiers_set, connections=connections_set) is_new = False if device is None: is_new = True - - deleted_device = self.deleted_devices.get_entry(identifiers, connections) + deleted_device = self.deleted_devices.get_entry(identifiers_set, connections_set) if deleted_device is None: - area_id: str | None = None + area_id = None if ( suggested_area is not None and suggested_area is not UNDEFINED @@ -938,24 +947,100 @@ def async_get_or_create( area = ar.async_get(self.hass).async_get_or_create(suggested_area) area_id = area.id device = DeviceEntry(area_id=area_id) - else: self.deleted_devices.pop(deleted_device.id) device = deleted_device.to_device_entry( config_entry, # Interpret not specifying a subentry as None config_subentry_id if config_subentry_id is not UNDEFINED else None, - connections, - identifiers, + connections_set, + identifiers_set, disabled_by, ) disabled_by = UNDEFINED self.devices[device.id] = device - # If creating a new device, default to the config entry name - if device_info_type == "primary" and (not name or name is UNDEFINED): - name = config_entry.title + return device, is_new + + @callback + def async_get_or_create( + self, + *, + config_entry_id: str, + config_subentry_id: str | None | UndefinedType = UNDEFINED, + configuration_url: str | URL | None | UndefinedType = UNDEFINED, + connections: set[tuple[str, str]] | None | UndefinedType = UNDEFINED, + created_at: str | datetime | UndefinedType = UNDEFINED, # will be ignored + default_manufacturer: str | None | UndefinedType = UNDEFINED, + default_model: str | None | UndefinedType = UNDEFINED, + default_name: str | None | UndefinedType = UNDEFINED, + # To disable a device if it gets created, does not affect existing devices + disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, + entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED, + hw_version: str | None | UndefinedType = UNDEFINED, + identifiers: set[tuple[str, str]] | None | UndefinedType = UNDEFINED, + manufacturer: str | None | UndefinedType = UNDEFINED, + model: str | None | UndefinedType = UNDEFINED, + model_id: str | None | UndefinedType = UNDEFINED, + modified_at: str | datetime | UndefinedType = UNDEFINED, # will be ignored + name: str | None | UndefinedType = UNDEFINED, + serial_number: str | None | UndefinedType = UNDEFINED, + suggested_area: str | None | UndefinedType = UNDEFINED, + sw_version: str | None | UndefinedType = UNDEFINED, + translation_key: str | None = None, + translation_placeholders: Mapping[str, str] | None = None, + via_device: tuple[str, str] | None | UndefinedType = UNDEFINED, + ) -> DeviceEntry: + """Get device. Create if it doesn't exist.""" + config_entry = self.hass.config_entries.async_get_entry(config_entry_id) + if config_entry is None: + raise HomeAssistantError( + f"Can't link device to unknown config entry {config_entry_id}" + ) + + ( + device_info, + device_info_type, + connections_set, + identifiers_set, + name, + ) = self._prepare_device_info_and_normalize( + config_entry, + configuration_url, + connections, + identifiers, + translation_key, + translation_placeholders, + name, + default_manufacturer, + default_model, + default_name, + entry_type, + hw_version, + manufacturer, + model, + model_id, + serial_number, + suggested_area, + sw_version, + via_device, + ) + + device, is_new = self._create_or_restore_device( + config_entry, + config_subentry_id, + connections_set, + identifiers_set, + suggested_area, + disabled_by, + ) + + # If creating a new device, default to the config entry name + if is_new and device_info_type == "primary" and (not name or name is UNDEFINED): + name = config_entry.title + + # Apply defaults when the device does not already have the values if default_manufacturer is not UNDEFINED and device.manufacturer is None: manufacturer = default_manufacturer @@ -991,8 +1076,8 @@ def async_get_or_create( hw_version=hw_version, is_new=is_new, manufacturer=manufacturer, - merge_connections=connections or UNDEFINED, - merge_identifiers=identifiers or UNDEFINED, + merge_connections=connections_set or UNDEFINED, + merge_identifiers=identifiers_set or UNDEFINED, model=model, model_id=model_id, name=name, @@ -1042,6 +1127,11 @@ def _async_update_device( # noqa: C901 sw_version: str | None | UndefinedType = UNDEFINED, via_device_id: str | None | UndefinedType = UNDEFINED, ) -> DeviceEntry | None: + # Core internal routine that applies attribute updates to a DeviceEntry. + # Handles adding/removing config entries, merging/validating identifiers + # and connections, updating scalar attributes and handling persistence + # and event emission. Returns the updated DeviceEntry or None when + # the device was removed as part of the operation. """Private update device attributes. :param add_config_subentry_id: Add the device to a specific subentry of add_config_entry_id @@ -1049,37 +1139,13 @@ def _async_update_device( # noqa: C901 """ old = self.devices[device_id] - new_values: dict[str, Any] = {} # Dict with new key/value pairs - old_values: dict[str, Any] = {} # Dict with old key/value pairs + new_values: dict[str, Any] = {} + old_values: dict[str, Any] = {} config_entries = old.config_entries config_entries_subentries = old.config_entries_subentries - if add_config_entry_id is not UNDEFINED: - if ( - add_config_entry := self.hass.config_entries.async_get_entry( - add_config_entry_id - ) - ) is None: - raise HomeAssistantError( - f"Can't link device to unknown config entry {add_config_entry_id}" - ) - - if add_config_subentry_id is not UNDEFINED: - if add_config_entry_id is UNDEFINED: - raise HomeAssistantError( - "Can't add config subentry without specifying config entry" - ) - if ( - add_config_subentry_id - # mypy says add_config_entry can be None. That's impossible, because we - # raise above if that happens - and add_config_subentry_id not in add_config_entry.subentries # type: ignore[union-attr] - ): - raise HomeAssistantError( - f"Config entry {add_config_entry_id} has no subentry {add_config_subentry_id}" - ) - + # Basic validation of parameters that must be consistent if ( remove_config_subentry_id is not UNDEFINED and remove_config_entry_id is UNDEFINED @@ -1103,100 +1169,41 @@ def _async_update_device( # noqa: C901 "Cannot define both merge_identifiers and new_identifiers" ) + # Process adding a config entry (if provided) if add_config_entry_id is not UNDEFINED: - if add_config_subentry_id is UNDEFINED: - # Interpret not specifying a subentry as None (the main entry) - add_config_subentry_id = None - - primary_entry_id = old.primary_config_entry - if ( - device_info_type == "primary" - and add_config_entry_id != primary_entry_id - ): - if ( - primary_entry_id is None - or not ( - primary_entry := self.hass.config_entries.async_get_entry( - primary_entry_id - ) - ) - or primary_entry.domain in LOW_PRIO_CONFIG_ENTRY_DOMAINS - ): - new_values["primary_config_entry"] = add_config_entry_id - old_values["primary_config_entry"] = primary_entry_id - - if add_config_entry_id not in old.config_entries: - config_entries = old.config_entries | {add_config_entry_id} - config_entries_subentries = old.config_entries_subentries | { - add_config_entry_id: {add_config_subentry_id} - } - # Enable the device if it was disabled by config entry and we're adding - # a non disabled config entry - if ( - # mypy says add_config_entry can be None. That's impossible, because we - # raise above if that happens - not add_config_entry.disabled_by # type: ignore[union-attr] - and old.disabled_by is DeviceEntryDisabler.CONFIG_ENTRY - ): - new_values["disabled_by"] = None - old_values["disabled_by"] = old.disabled_by - elif ( - add_config_subentry_id - not in old.config_entries_subentries[add_config_entry_id] - ): - config_entries_subentries = old.config_entries_subentries | { - add_config_entry_id: old.config_entries_subentries[ - add_config_entry_id - ] - | {add_config_subentry_id} - } + ( + config_entries, + config_entries_subentries, + new_vals_add, + old_vals_add, + ) = self._process_add_config_entry( + old, add_config_entry_id, add_config_subentry_id, device_info_type + ) + new_values.update(new_vals_add) + old_values.update(old_vals_add) + # Process removing a config entry (if provided and present) if ( remove_config_entry_id is not UNDEFINED and remove_config_entry_id in config_entries ): - if remove_config_subentry_id is UNDEFINED: - config_entries_subentries = dict(old.config_entries_subentries) - del config_entries_subentries[remove_config_entry_id] - elif ( - remove_config_subentry_id - in old.config_entries_subentries[remove_config_entry_id] - ): - config_entries_subentries = old.config_entries_subentries | { - remove_config_entry_id: old.config_entries_subentries[ - remove_config_entry_id - ] - - {remove_config_subentry_id} - } - if not config_entries_subentries[remove_config_entry_id]: - del config_entries_subentries[remove_config_entry_id] - - if remove_config_entry_id not in config_entries_subentries: - if config_entries == {remove_config_entry_id}: - self.async_remove_device(device_id) - return None - - if remove_config_entry_id == old.primary_config_entry: - new_values["primary_config_entry"] = None - old_values["primary_config_entry"] = old.primary_config_entry - - config_entries = config_entries - {remove_config_entry_id} - - # Disable the device if it is enabled and all remaining config entries - # are disabled - has_enabled_config_entries = any( - config_entry.disabled_by is None - for config_entry_id in config_entries - if ( - config_entry := self.hass.config_entries.async_get_entry( - config_entry_id - ) - ) - is not None - ) - if not has_enabled_config_entries and old.disabled_by is None: - new_values["disabled_by"] = DeviceEntryDisabler.CONFIG_ENTRY - old_values["disabled_by"] = old.disabled_by + ( + config_entries, + config_entries_subentries, + new_vals_rem, + old_vals_rem, + ) = self._process_remove_config_entry( + old, + remove_config_entry_id, + remove_config_subentry_id, + config_entries, + config_entries_subentries, + ) + # If helper indicates device should be removed + if new_vals_rem is None: + self.async_remove_device(device_id) + new_values.update(new_vals_rem) + old_values.update(old_vals_rem) if config_entries != old.config_entries: new_values["config_entries"] = config_entries @@ -1206,46 +1213,28 @@ def _async_update_device( # noqa: C901 new_values["config_entries_subentries"] = config_entries_subentries old_values["config_entries_subentries"] = old.config_entries_subentries - added_connections: set[tuple[str, str]] | None = None - added_identifiers: set[tuple[str, str]] | None = None - - if merge_connections is not UNDEFINED: - normalized_connections = self._validate_connections( - device_id, - merge_connections, - allow_collisions, - ) - old_connections = old.connections - if not normalized_connections.issubset(old_connections): - added_connections = normalized_connections - new_values["connections"] = old_connections | normalized_connections - old_values["connections"] = old_connections - - if merge_identifiers is not UNDEFINED: - merge_identifiers = self._validate_identifiers( - device_id, merge_identifiers, allow_collisions - ) - old_identifiers = old.identifiers - if not merge_identifiers.issubset(old_identifiers): - added_identifiers = merge_identifiers - new_values["identifiers"] = old_identifiers | merge_identifiers - old_values["identifiers"] = old_identifiers - - if new_connections is not UNDEFINED: - added_connections = new_values["connections"] = self._validate_connections( - device_id, new_connections, False - ) - old_values["connections"] = old.connections - - if new_identifiers is not UNDEFINED: - added_identifiers = new_values["identifiers"] = self._validate_identifiers( - device_id, new_identifiers, False - ) - old_values["identifiers"] = old.identifiers + # Process connections and identifiers via helper + ( + new_conn_ident_vals, + old_conn_ident_vals, + added_connections, + added_identifiers, + ) = self._process_connection_identifier_changes( + old, + device_id, + merge_connections, + merge_identifiers, + new_connections, + new_identifiers, + allow_collisions, + ) + new_values.update(new_conn_ident_vals) + old_values.update(old_conn_ident_vals) if configuration_url is not UNDEFINED: configuration_url = _validate_configuration_url(configuration_url) + # Process simple scalar/set attributes for attr_name, value in ( ("area_id", area_id), ("configuration_url", configuration_url), @@ -1283,9 +1272,7 @@ def _async_update_device( # noqa: C901 new = attr.evolve(old, **new_values) self.devices[device_id] = new - # NOTE: Once we solve the broader issue of duplicated devices, we might - # want to revisit it. Instead of simply removing the duplicated deleted device, - # we might want to merge the information from it into the non-deleted device. + # Remove any deleted devices that are resurrected by added identifiers/connections for deleted_device in self.deleted_devices.get_entries( added_identifiers, added_connections ): @@ -1343,6 +1330,9 @@ def async_update_device( sw_version: str | None | UndefinedType = UNDEFINED, via_device_id: str | None | UndefinedType = UNDEFINED, ) -> DeviceEntry | None: + # Public wrapper around `_async_update_device` that performs deprecation + # reporting for `suggested_area` and forwards arguments to the internal + # implementation. """Update device attributes. :param add_config_subentry_id: Add the device to a specific subentry of add_config_entry_id @@ -1390,6 +1380,9 @@ def _validate_connections( connections: set[tuple[str, str]], allow_collisions: bool, ) -> set[tuple[str, str]]: + # Normalize and optionally check for collisions of connections. + # If allow_collisions is False this will raise DeviceConnectionCollisionError + # when another device already claims a connection tuple. """Normalize and validate connections, raise on collision with other devices.""" normalized_connections = _normalize_connections(connections) if allow_collisions: @@ -1415,6 +1408,7 @@ def _validate_identifiers( identifiers: set[tuple[str, str]], allow_collisions: bool, ) -> set[tuple[str, str]]: + # Validate identifier collisions similar to `_validate_connections`. """Validate identifiers, raise on collision with other devices.""" if allow_collisions: return identifiers @@ -1432,6 +1426,9 @@ def _validate_identifiers( @callback def async_remove_device(self, device_id: str) -> None: + # Delete an active device, move it to the deleted_devices index and + # fire the registry removed event. Also updates any devices that + # referenced the removed device via `via_device_id`. """Remove a device from the device registry.""" self.hass.verify_event_loop_thread("device_registry.async_remove_device") device = self.devices.pop(device_id) @@ -1460,8 +1457,10 @@ def async_remove_device(self, device_id: str) -> None: ) self.async_schedule_save() + # Load the device registry from storage. async def async_load(self) -> None: - """Load the device registry.""" + # Load persisted device registry data from storage and populate + # internal indexes for active and deleted devices. async_setup_cleanup(self.hass, self) data = await self._store.async_load() @@ -1517,8 +1516,8 @@ async def async_load(self) -> None: ) # Introduced in 0.111 - def get_optional_enum[_EnumT: StrEnum]( - cls: type[_EnumT], value: str | None, undefined: bool + def get_optional_enum( + cls: Type[_EnumT], value: str | None, undefined: bool ) -> _EnumT | UndefinedType | None: """Convert string to the passed enum, UNDEFINED or None.""" if undefined: @@ -1559,9 +1558,10 @@ def get_optional_enum[_EnumT: StrEnum]( self.deleted_devices = deleted_devices self._device_data = devices.data + # Return data of the device registry to store in a file. @callback def _data_to_save(self) -> dict[str, Any]: - """Return data of device registry to store in a file.""" + # Prepare the JSON-serializable structure that will be written to disk. return { "devices": [entry.as_storage_fragment for entry in self.devices.values()], "deleted_devices": [ @@ -1569,9 +1569,12 @@ def _data_to_save(self) -> dict[str, Any]: ], } + # Clear config entry from registry devices. @callback def async_clear_config_entry(self, config_entry_id: str) -> None: - """Clear config entry from registry entries.""" + # Remove references to a config entry from all devices. If a deleted + # device loses its last config entry we mark it orphaned and set + # an orphan timestamp for later purging. now_time = time.time() for device in self.devices.get_devices_for_config_entry_id(config_entry_id): self._async_update_device(device.id, remove_config_entry_id=config_entry_id) @@ -1602,10 +1605,13 @@ def async_clear_config_entry(self, config_entry_id: str) -> None: ) self.async_schedule_save() + # Clear config entry from registry entries. @callback def async_clear_config_subentry( self, config_entry_id: str, config_subentry_id: str ) -> None: + # Remove references to a specific subentry of a config entry from + # devices; used when a subflow is removed. """Clear config entry from registry entries.""" now_time = time.time() for device in self.devices.get_devices_for_config_entry_id(config_entry_id): @@ -1655,6 +1661,8 @@ def async_purge_expired_orphaned_devices(self) -> None: We need to purge these periodically to avoid the database growing without bound. """ + # Remove deleted devices which have been orphaned for longer than + # ORPHANED_DEVICE_KEEP_SECONDS to prevent unbounded growth. now_time = time.time() for deleted_device in list(self.deleted_devices.values()): if deleted_device.orphaned_timestamp is None: @@ -1666,9 +1674,9 @@ def async_purge_expired_orphaned_devices(self) -> None: ): del self.deleted_devices[deleted_device.id] + # Clear area id from registry entries. @callback def async_clear_area_id(self, area_id: str) -> None: - """Clear area id from registry entries.""" for device in self.devices.get_devices_for_area_id(area_id): self._async_update_device(device.id, area_id=None) for deleted_device in list(self.deleted_devices.values()): @@ -1679,9 +1687,9 @@ def async_clear_area_id(self, area_id: str) -> None: ) self.async_schedule_save() + # Clear label id from registry entries. @callback def async_clear_label_id(self, label_id: str) -> None: - """Clear label from registry entries.""" for device in self.devices.get_devices_for_label(label_id): self._async_update_device(device.id, labels=device.labels - {label_id}) for deleted_device in list(self.deleted_devices.values()): @@ -1692,39 +1700,39 @@ def async_clear_label_id(self, label_id: str) -> None: ) self.async_schedule_save() - +# Get the device registry. @callback @singleton(DATA_REGISTRY) def async_get(hass: HomeAssistant) -> DeviceRegistry: - """Get device registry.""" + # Return the singleton DeviceRegistry instance for this HomeAssistant + # instance. Uses the @singleton decorator to ensure a single registry. return DeviceRegistry(hass) - +# Load the device registry. async def async_load(hass: HomeAssistant) -> None: - """Load device registry.""" assert DATA_REGISTRY not in hass.data await async_get(hass).async_load() - +# Return entries for area. @callback def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> list[DeviceEntry]: - """Return entries that match an area.""" + # Convenience helper: return active devices assigned to the given area id. return registry.devices.get_devices_for_area_id(area_id) - +# Return entries for label. @callback def async_entries_for_label( registry: DeviceRegistry, label_id: str ) -> list[DeviceEntry]: - """Return entries that match a label.""" + # Convenience helper: return active devices that have the given label. return registry.devices.get_devices_for_label(label_id) - +# Return entries for config entry. @callback def async_entries_for_config_entry( registry: DeviceRegistry, config_entry_id: str ) -> list[DeviceEntry]: - """Return entries that match a config entry.""" + # Convenience helper: return active devices linked to the given config entry. return registry.devices.get_devices_for_config_entry_id(config_entry_id) @@ -1768,14 +1776,13 @@ def async_config_entry_disabled_by_changed( device.id, disabled_by=DeviceEntryDisabler.CONFIG_ENTRY ) - +# Cleanup device registry @callback def async_cleanup( hass: HomeAssistant, dev_reg: DeviceRegistry, ent_reg: entity_registry.EntityRegistry, ) -> None: - """Clean up device registry.""" # Find all devices that are referenced by a config_entry. config_entry_ids = set(hass.config_entries.async_entry_ids()) references_config_entries = { @@ -1810,22 +1817,21 @@ def async_cleanup( # growing without bounds when there are lots of deleted devices dev_reg.async_purge_expired_orphaned_devices() - +# Clean up device registry when entities removed. @callback def async_setup_cleanup(hass: HomeAssistant, dev_reg: DeviceRegistry) -> None: - """Clean up device registry when entities removed.""" from . import entity_registry, label_registry as lr # noqa: PLC0415 + # Filter all except for the remove action from label registry events. @callback def _label_removed_from_registry_filter( event_data: lr.EventLabelRegistryUpdatedData, ) -> bool: - """Filter all except for the remove action from label registry events.""" return event_data["action"] == "remove" + # Update devices that have a label that has been removed. @callback def _handle_label_registry_update(event: lr.EventLabelRegistryUpdated) -> None: - """Update devices that have a label that has been removed.""" dev_reg.async_clear_label_id(event.data["label_id"]) hass.bus.async_listen( @@ -1834,9 +1840,9 @@ def _handle_label_registry_update(event: lr.EventLabelRegistryUpdated) -> None: listener=_handle_label_registry_update, ) + # Cleanup @callback def _async_cleanup() -> None: - """Cleanup.""" ent_reg = entity_registry.async_get(hass) async_cleanup(hass, dev_reg, ent_reg) @@ -1844,18 +1850,18 @@ def _async_cleanup() -> None: hass, _LOGGER, cooldown=CLEANUP_DELAY, immediate=False, function=_async_cleanup ) + # Handle entity updated or removed dispatch. @callback def _async_entity_registry_changed( event: Event[entity_registry.EventEntityRegistryUpdatedData], ) -> None: - """Handle entity updated or removed dispatch.""" debounced_cleanup.async_schedule_call() + # Handle entity updated or removed filter @callback def entity_registry_changed_filter( event_data: entity_registry.EventEntityRegistryUpdatedData, ) -> bool: - """Handle entity updated or removed filter.""" if ( event_data["action"] == "update" and "device_id" not in event_data["changes"] @@ -1864,8 +1870,8 @@ def entity_registry_changed_filter( return True + # Listen for entity registry changes. def _async_listen_for_cleanup() -> None: - """Listen for entity registry changes.""" hass.bus.async_listen( entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_changed, @@ -1883,9 +1889,9 @@ async def startup_clean(event: Event) -> None: hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean) + # Cancel cleanup on shutdown @callback def _on_homeassistant_stop(event: Event) -> None: - """Cancel debounced cleanup.""" debounced_cleanup.async_cancel() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _on_homeassistant_stop)