Add zha typing [core.channels] (#68377)

This commit is contained in:
Marc Mueller 2022-03-22 15:14:35 +01:00 committed by GitHub
parent bdc92271f2
commit df05e8b950
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,12 +2,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import Any from collections.abc import Coroutine
from typing import TYPE_CHECKING, Any, TypeVar
import zigpy.endpoint
import zigpy.zcl.clusters.closures import zigpy.zcl.clusters.closures
from homeassistant.const import ATTR_DEVICE_ID from homeassistant.const import ATTR_DEVICE_ID
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from . import ( # noqa: F401 from . import ( # noqa: F401
@ -29,20 +31,25 @@ from .. import (
device as zha_core_device, device as zha_core_device,
discovery as zha_disc, discovery as zha_disc,
registries as zha_regs, registries as zha_regs,
typing as zha_typing,
) )
ChannelsDict = dict[str, zha_typing.ChannelType] if TYPE_CHECKING:
from ...entity import ZhaEntity
from ..device import ZHADevice
_ChannelsT = TypeVar("_ChannelsT", bound="Channels")
_ChannelPoolT = TypeVar("_ChannelPoolT", bound="ChannelPool")
_ChannelsDictType = dict[str, base.ZigbeeChannel]
class Channels: class Channels:
"""All discovered channels of a device.""" """All discovered channels of a device."""
def __init__(self, zha_device: zha_typing.ZhaDeviceType) -> None: def __init__(self, zha_device: ZHADevice) -> None:
"""Initialize instance.""" """Initialize instance."""
self._pools: list[zha_typing.ChannelPoolType] = [] self._pools: list[ChannelPool] = []
self._power_config = None self._power_config: base.ZigbeeChannel | None = None
self._identify = None self._identify: base.ZigbeeChannel | None = None
self._semaphore = asyncio.Semaphore(3) self._semaphore = asyncio.Semaphore(3)
self._unique_id = str(zha_device.ieee) self._unique_id = str(zha_device.ieee)
self._zdo_channel = base.ZDOChannel(zha_device.device.endpoints[0], zha_device) self._zdo_channel = base.ZDOChannel(zha_device.device.endpoints[0], zha_device)
@ -54,23 +61,23 @@ class Channels:
return self._pools return self._pools
@property @property
def power_configuration_ch(self) -> zha_typing.ChannelType: def power_configuration_ch(self) -> base.ZigbeeChannel | None:
"""Return power configuration channel.""" """Return power configuration channel."""
return self._power_config return self._power_config
@power_configuration_ch.setter @power_configuration_ch.setter
def power_configuration_ch(self, channel: zha_typing.ChannelType) -> None: def power_configuration_ch(self, channel: base.ZigbeeChannel) -> None:
"""Power configuration channel setter.""" """Power configuration channel setter."""
if self._power_config is None: if self._power_config is None:
self._power_config = channel self._power_config = channel
@property @property
def identify_ch(self) -> zha_typing.ChannelType: def identify_ch(self) -> base.ZigbeeChannel | None:
"""Return power configuration channel.""" """Return power configuration channel."""
return self._identify return self._identify
@identify_ch.setter @identify_ch.setter
def identify_ch(self, channel: zha_typing.ChannelType) -> None: def identify_ch(self, channel: base.ZigbeeChannel) -> None:
"""Power configuration channel setter.""" """Power configuration channel setter."""
if self._identify is None: if self._identify is None:
self._identify = channel self._identify = channel
@ -81,17 +88,17 @@ class Channels:
return self._semaphore return self._semaphore
@property @property
def zdo_channel(self) -> zha_typing.ZDOChannelType: def zdo_channel(self) -> base.ZDOChannel:
"""Return ZDO channel.""" """Return ZDO channel."""
return self._zdo_channel return self._zdo_channel
@property @property
def zha_device(self) -> zha_typing.ZhaDeviceType: def zha_device(self) -> ZHADevice:
"""Return parent zha device.""" """Return parent zha device."""
return self._zha_device return self._zha_device
@property @property
def unique_id(self): def unique_id(self) -> str:
"""Return the unique id for this channel.""" """Return the unique id for this channel."""
return self._unique_id return self._unique_id
@ -104,7 +111,7 @@ class Channels:
} }
@classmethod @classmethod
def new(cls, zha_device: zha_typing.ZhaDeviceType) -> Channels: def new(cls: type[_ChannelsT], zha_device: ZHADevice) -> _ChannelsT:
"""Create new instance.""" """Create new instance."""
channels = cls(zha_device) channels = cls(zha_device)
for ep_id in sorted(zha_device.device.endpoints): for ep_id in sorted(zha_device.device.endpoints):
@ -142,9 +149,9 @@ class Channels:
def async_new_entity( def async_new_entity(
self, self,
component: str, component: str,
entity_class: zha_typing.CALLABLE_T, entity_class: type[ZhaEntity],
unique_id: str, unique_id: str,
channels: list[zha_typing.ChannelType], channels: list[base.ZigbeeChannel],
): ):
"""Signal new entity addition.""" """Signal new entity addition."""
if self.zha_device.status == zha_core_device.DeviceStatus.INITIALIZED: if self.zha_device.status == zha_core_device.DeviceStatus.INITIALIZED:
@ -178,30 +185,30 @@ class ChannelPool:
def __init__(self, channels: Channels, ep_id: int) -> None: def __init__(self, channels: Channels, ep_id: int) -> None:
"""Initialize instance.""" """Initialize instance."""
self._all_channels: ChannelsDict = {} self._all_channels: _ChannelsDictType = {}
self._channels: Channels = channels self._channels = channels
self._claimed_channels: ChannelsDict = {} self._claimed_channels: _ChannelsDictType = {}
self._id: int = ep_id self._id = ep_id
self._client_channels: dict[str, zha_typing.ClientChannelType] = {} self._client_channels: dict[str, base.ClientChannel] = {}
self._unique_id: str = f"{channels.unique_id}-{ep_id}" self._unique_id = f"{channels.unique_id}-{ep_id}"
@property @property
def all_channels(self) -> ChannelsDict: def all_channels(self) -> _ChannelsDictType:
"""All server channels of an endpoint.""" """All server channels of an endpoint."""
return self._all_channels return self._all_channels
@property @property
def claimed_channels(self) -> ChannelsDict: def claimed_channels(self) -> _ChannelsDictType:
"""Channels in use.""" """Channels in use."""
return self._claimed_channels return self._claimed_channels
@property @property
def client_channels(self) -> dict[str, zha_typing.ClientChannelType]: def client_channels(self) -> dict[str, base.ClientChannel]:
"""Return a dict of client channels.""" """Return a dict of client channels."""
return self._client_channels return self._client_channels
@property @property
def endpoint(self) -> zha_typing.ZigpyEndpointType: def endpoint(self) -> zigpy.endpoint.Endpoint:
"""Return endpoint of zigpy device.""" """Return endpoint of zigpy device."""
return self._channels.zha_device.device.endpoints[self.id] return self._channels.zha_device.device.endpoints[self.id]
@ -216,7 +223,7 @@ class ChannelPool:
return self._channels.zha_device.nwk return self._channels.zha_device.nwk
@property @property
def is_mains_powered(self) -> bool: def is_mains_powered(self) -> bool | None:
"""Device is_mains_powered.""" """Device is_mains_powered."""
return self._channels.zha_device.is_mains_powered return self._channels.zha_device.is_mains_powered
@ -231,7 +238,7 @@ class ChannelPool:
return self._channels.zha_device.manufacturer_code return self._channels.zha_device.manufacturer_code
@property @property
def hass(self): def hass(self) -> HomeAssistant:
"""Return hass.""" """Return hass."""
return self._channels.zha_device.hass return self._channels.zha_device.hass
@ -246,7 +253,7 @@ class ChannelPool:
return self._channels.zha_device.skip_configuration return self._channels.zha_device.skip_configuration
@property @property
def unique_id(self): def unique_id(self) -> str:
"""Return the unique id for this channel.""" """Return the unique id for this channel."""
return self._unique_id return self._unique_id
@ -272,7 +279,7 @@ class ChannelPool:
) )
@classmethod @classmethod
def new(cls, channels: Channels, ep_id: int) -> ChannelPool: def new(cls: type[_ChannelPoolT], channels: Channels, ep_id: int) -> _ChannelPoolT:
"""Create new channels for an endpoint.""" """Create new channels for an endpoint."""
pool = cls(channels, ep_id) pool = cls(channels, ep_id)
pool.add_all_channels() pool.add_all_channels()
@ -330,7 +337,7 @@ class ChannelPool:
async def _execute_channel_tasks(self, func_name: str, *args: Any) -> None: async def _execute_channel_tasks(self, func_name: str, *args: Any) -> None:
"""Add a throttled channel task and swallow exceptions.""" """Add a throttled channel task and swallow exceptions."""
async def _throttle(coro): async def _throttle(coro: Coroutine[Any, Any, None]) -> None:
async with self._channels.semaphore: async with self._channels.semaphore:
return await coro return await coro
@ -347,9 +354,9 @@ class ChannelPool:
def async_new_entity( def async_new_entity(
self, self,
component: str, component: str,
entity_class: zha_typing.CALLABLE_T, entity_class: type[ZhaEntity],
unique_id: str, unique_id: str,
channels: list[zha_typing.ChannelType], channels: list[base.ZigbeeChannel],
): ):
"""Signal new entity addition.""" """Signal new entity addition."""
self._channels.async_new_entity(component, entity_class, unique_id, channels) self._channels.async_new_entity(component, entity_class, unique_id, channels)
@ -360,12 +367,12 @@ class ChannelPool:
self._channels.async_send_signal(signal, *args) self._channels.async_send_signal(signal, *args)
@callback @callback
def claim_channels(self, channels: list[zha_typing.ChannelType]) -> None: def claim_channels(self, channels: list[base.ZigbeeChannel]) -> None:
"""Claim a channel.""" """Claim a channel."""
self.claimed_channels.update({ch.id: ch for ch in channels}) self.claimed_channels.update({ch.id: ch for ch in channels})
@callback @callback
def unclaimed_channels(self) -> list[zha_typing.ChannelType]: def unclaimed_channels(self) -> list[base.ZigbeeChannel]:
"""Return a list of available (unclaimed) channels.""" """Return a list of available (unclaimed) channels."""
claimed = set(self.claimed_channels) claimed = set(self.claimed_channels)
available = set(self.all_channels) available = set(self.all_channels)