diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 6c600bf93d6..64f7b24ff99 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -11,7 +11,7 @@ import logging import os import time import traceback -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Union from serial import SerialException from zigpy.application import ControllerApplication @@ -31,6 +31,7 @@ from homeassistant.helpers.device_registry import ( async_get_registry as get_dev_reg, ) from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_registry import ( EntityRegistry, async_entries_for_device, @@ -96,16 +97,22 @@ if TYPE_CHECKING: from logging import Filter, LogRecord from ..entity import ZhaEntity + from .channels.base import ZigbeeChannel from .store import ZhaStorage _LogFilterType = Union[Filter, Callable[[LogRecord], int]] _LOGGER = logging.getLogger(__name__) -EntityReference = collections.namedtuple( - "EntityReference", - "reference_id zha_device cluster_channels device_info remove_future", -) + +class EntityReference(NamedTuple): + """Describes an entity reference.""" + + reference_id: str + zha_device: ZHADevice + cluster_channels: dict[str, ZigbeeChannel] + device_info: DeviceInfo + remove_future: asyncio.Future[Any] class DevicePairingStatus(Enum): @@ -362,7 +369,7 @@ class ZHAGateway: self, device: ZHADevice, entity_refs: list[EntityReference] | None ) -> None: if entity_refs is not None: - remove_tasks = [] + remove_tasks: list[asyncio.Future[Any]] = [] for entity_ref in entity_refs: remove_tasks.append(entity_ref.remove_future) if remove_tasks: @@ -413,13 +420,14 @@ class ZHAGateway: ): if entity_id == entity_reference.reference_id: return entity_reference + return None def remove_entity_reference(self, entity: ZhaEntity) -> None: """Remove entity reference for given entity_id if found.""" if entity.zha_device.ieee in self.device_registry: entity_refs = self.device_registry.get(entity.zha_device.ieee) self.device_registry[entity.zha_device.ieee] = [ - e for e in entity_refs if e.reference_id != entity.entity_id + e for e in entity_refs if e.reference_id != entity.entity_id # type: ignore[union-attr] ] def _cleanup_group_entity_registry_entries( @@ -470,12 +478,12 @@ class ZHAGateway: def register_entity_reference( self, - ieee, - reference_id, - zha_device, - cluster_channels, - device_info, - remove_future, + ieee: EUI64, + reference_id: str, + zha_device: ZHADevice, + cluster_channels: dict[str, ZigbeeChannel], + device_info: DeviceInfo, + remove_future: asyncio.Future[Any], ): """Record the creation of a hass entity associated with ieee.""" self._device_registry[ieee].append( diff --git a/homeassistant/components/zha/entity.py b/homeassistant/components/zha/entity.py index 0b7f95efb64..4a9b0f7577c 100644 --- a/homeassistant/components/zha/entity.py +++ b/homeassistant/components/zha/entity.py @@ -2,10 +2,9 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable import functools import logging -from typing import Any +from typing import TYPE_CHECKING, Any from homeassistant.const import ATTR_NAME from homeassistant.core import CALLBACK_TYPE, Event, callback @@ -32,6 +31,10 @@ from .core.const import ( from .core.helpers import LogMixin from .core.typing import CALLABLE_T, ChannelType, ZhaDeviceType +if TYPE_CHECKING: + from .core.channels.base import ZigbeeChannel + from .core.device import ZHADevice + _LOGGER = logging.getLogger(__name__) ENTITY_SUFFIX = "entity_suffix" @@ -43,7 +46,7 @@ class BaseZhaEntity(LogMixin, entity.Entity): unique_id_suffix: str | None = None - def __init__(self, unique_id: str, zha_device: ZhaDeviceType, **kwargs) -> None: + def __init__(self, unique_id: str, zha_device: ZHADevice, **kwargs: Any) -> None: """Init ZHA entity.""" self._name: str = "" self._force_update: bool = False @@ -53,9 +56,9 @@ class BaseZhaEntity(LogMixin, entity.Entity): self._unique_id += f"-{self.unique_id_suffix}" self._state: Any = None self._extra_state_attributes: dict[str, Any] = {} - self._zha_device: ZhaDeviceType = zha_device + self._zha_device = zha_device self._unsubs: list[CALLABLE_T] = [] - self.remove_future: Awaitable[None] = None + self.remove_future: asyncio.Future[Any] = asyncio.Future() @property def name(self) -> str: @@ -68,7 +71,7 @@ class BaseZhaEntity(LogMixin, entity.Entity): return self._unique_id @property - def zha_device(self) -> ZhaDeviceType: + def zha_device(self) -> ZHADevice: """Return the zha device this entity is attached to.""" return self._zha_device @@ -159,9 +162,9 @@ class ZhaEntity(BaseZhaEntity, RestoreEntity): def __init__( self, unique_id: str, - zha_device: ZhaDeviceType, - channels: list[ChannelType], - **kwargs, + zha_device: ZHADevice, + channels: list[ZigbeeChannel], + **kwargs: Any, ) -> None: """Init ZHA entity.""" super().__init__(unique_id, zha_device, **kwargs) @@ -170,7 +173,7 @@ class ZhaEntity(BaseZhaEntity, RestoreEntity): self._name: str = f"{zha_device.name} {ieeetail} {ch_names}" if self.unique_id_suffix: self._name += f" {self.unique_id_suffix}" - self.cluster_channels: dict[str, ChannelType] = {} + self.cluster_channels: dict[str, ZigbeeChannel] = {} for channel in channels: self.cluster_channels[channel.name] = channel