diff --git a/homeassistant/components/zha/core/channels/__init__.py b/homeassistant/components/zha/core/channels/__init__.py index d60c38c69a6..b63c20e14eb 100644 --- a/homeassistant/components/zha/core/channels/__init__.py +++ b/homeassistant/components/zha/core/channels/__init__.py @@ -2,12 +2,14 @@ from __future__ import annotations import asyncio -from typing import Any +from collections.abc import Coroutine +from typing import TYPE_CHECKING, Any, TypeVar +import zigpy.endpoint import zigpy.zcl.clusters.closures from homeassistant.const import ATTR_DEVICE_ID -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_send from . import ( # noqa: F401 @@ -29,20 +31,25 @@ from .. import ( device as zha_core_device, discovery as zha_disc, registries as zha_regs, - typing as zha_typing, ) -ChannelsDict = dict[str, zha_typing.ChannelType] +if TYPE_CHECKING: + from ...entity import ZhaEntity + from ..device import ZHADevice + +_ChannelsT = TypeVar("_ChannelsT", bound="Channels") +_ChannelPoolT = TypeVar("_ChannelPoolT", bound="ChannelPool") +_ChannelsDictType = dict[str, base.ZigbeeChannel] class Channels: """All discovered channels of a device.""" - def __init__(self, zha_device: zha_typing.ZhaDeviceType) -> None: + def __init__(self, zha_device: ZHADevice) -> None: """Initialize instance.""" - self._pools: list[zha_typing.ChannelPoolType] = [] - self._power_config = None - self._identify = None + self._pools: list[ChannelPool] = [] + self._power_config: base.ZigbeeChannel | None = None + self._identify: base.ZigbeeChannel | None = None self._semaphore = asyncio.Semaphore(3) self._unique_id = str(zha_device.ieee) self._zdo_channel = base.ZDOChannel(zha_device.device.endpoints[0], zha_device) @@ -54,23 +61,23 @@ class Channels: return self._pools @property - def power_configuration_ch(self) -> zha_typing.ChannelType: + def power_configuration_ch(self) -> base.ZigbeeChannel | None: """Return power configuration channel.""" return self._power_config @power_configuration_ch.setter - def power_configuration_ch(self, channel: zha_typing.ChannelType) -> None: + def power_configuration_ch(self, channel: base.ZigbeeChannel) -> None: """Power configuration channel setter.""" if self._power_config is None: self._power_config = channel @property - def identify_ch(self) -> zha_typing.ChannelType: + def identify_ch(self) -> base.ZigbeeChannel | None: """Return power configuration channel.""" return self._identify @identify_ch.setter - def identify_ch(self, channel: zha_typing.ChannelType) -> None: + def identify_ch(self, channel: base.ZigbeeChannel) -> None: """Power configuration channel setter.""" if self._identify is None: self._identify = channel @@ -81,17 +88,17 @@ class Channels: return self._semaphore @property - def zdo_channel(self) -> zha_typing.ZDOChannelType: + def zdo_channel(self) -> base.ZDOChannel: """Return ZDO channel.""" return self._zdo_channel @property - def zha_device(self) -> zha_typing.ZhaDeviceType: + def zha_device(self) -> ZHADevice: """Return parent zha device.""" return self._zha_device @property - def unique_id(self): + def unique_id(self) -> str: """Return the unique id for this channel.""" return self._unique_id @@ -104,7 +111,7 @@ class Channels: } @classmethod - def new(cls, zha_device: zha_typing.ZhaDeviceType) -> Channels: + def new(cls: type[_ChannelsT], zha_device: ZHADevice) -> _ChannelsT: """Create new instance.""" channels = cls(zha_device) for ep_id in sorted(zha_device.device.endpoints): @@ -142,9 +149,9 @@ class Channels: def async_new_entity( self, component: str, - entity_class: zha_typing.CALLABLE_T, + entity_class: type[ZhaEntity], unique_id: str, - channels: list[zha_typing.ChannelType], + channels: list[base.ZigbeeChannel], ): """Signal new entity addition.""" if self.zha_device.status == zha_core_device.DeviceStatus.INITIALIZED: @@ -178,30 +185,30 @@ class ChannelPool: def __init__(self, channels: Channels, ep_id: int) -> None: """Initialize instance.""" - self._all_channels: ChannelsDict = {} - self._channels: Channels = channels - self._claimed_channels: ChannelsDict = {} - self._id: int = ep_id - self._client_channels: dict[str, zha_typing.ClientChannelType] = {} - self._unique_id: str = f"{channels.unique_id}-{ep_id}" + self._all_channels: _ChannelsDictType = {} + self._channels = channels + self._claimed_channels: _ChannelsDictType = {} + self._id = ep_id + self._client_channels: dict[str, base.ClientChannel] = {} + self._unique_id = f"{channels.unique_id}-{ep_id}" @property - def all_channels(self) -> ChannelsDict: + def all_channels(self) -> _ChannelsDictType: """All server channels of an endpoint.""" return self._all_channels @property - def claimed_channels(self) -> ChannelsDict: + def claimed_channels(self) -> _ChannelsDictType: """Channels in use.""" return self._claimed_channels @property - def client_channels(self) -> dict[str, zha_typing.ClientChannelType]: + def client_channels(self) -> dict[str, base.ClientChannel]: """Return a dict of client channels.""" return self._client_channels @property - def endpoint(self) -> zha_typing.ZigpyEndpointType: + def endpoint(self) -> zigpy.endpoint.Endpoint: """Return endpoint of zigpy device.""" return self._channels.zha_device.device.endpoints[self.id] @@ -216,7 +223,7 @@ class ChannelPool: return self._channels.zha_device.nwk @property - def is_mains_powered(self) -> bool: + def is_mains_powered(self) -> bool | None: """Device is_mains_powered.""" return self._channels.zha_device.is_mains_powered @@ -231,7 +238,7 @@ class ChannelPool: return self._channels.zha_device.manufacturer_code @property - def hass(self): + def hass(self) -> HomeAssistant: """Return hass.""" return self._channels.zha_device.hass @@ -246,7 +253,7 @@ class ChannelPool: return self._channels.zha_device.skip_configuration @property - def unique_id(self): + def unique_id(self) -> str: """Return the unique id for this channel.""" return self._unique_id @@ -272,7 +279,7 @@ class ChannelPool: ) @classmethod - def new(cls, channels: Channels, ep_id: int) -> ChannelPool: + def new(cls: type[_ChannelPoolT], channels: Channels, ep_id: int) -> _ChannelPoolT: """Create new channels for an endpoint.""" pool = cls(channels, ep_id) pool.add_all_channels() @@ -330,7 +337,7 @@ class ChannelPool: async def _execute_channel_tasks(self, func_name: str, *args: Any) -> None: """Add a throttled channel task and swallow exceptions.""" - async def _throttle(coro): + async def _throttle(coro: Coroutine[Any, Any, None]) -> None: async with self._channels.semaphore: return await coro @@ -347,9 +354,9 @@ class ChannelPool: def async_new_entity( self, component: str, - entity_class: zha_typing.CALLABLE_T, + entity_class: type[ZhaEntity], unique_id: str, - channels: list[zha_typing.ChannelType], + channels: list[base.ZigbeeChannel], ): """Signal new entity addition.""" self._channels.async_new_entity(component, entity_class, unique_id, channels) @@ -360,12 +367,12 @@ class ChannelPool: self._channels.async_send_signal(signal, *args) @callback - def claim_channels(self, channels: list[zha_typing.ChannelType]) -> None: + def claim_channels(self, channels: list[base.ZigbeeChannel]) -> None: """Claim a channel.""" self.claimed_channels.update({ch.id: ch for ch in channels}) @callback - def unclaimed_channels(self) -> list[zha_typing.ChannelType]: + def unclaimed_channels(self) -> list[base.ZigbeeChannel]: """Return a list of available (unclaimed) channels.""" claimed = set(self.claimed_channels) available = set(self.all_channels)