From 126320529e2153c11c2b3897689180a103e25712 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 19 Mar 2022 12:37:04 +0100 Subject: [PATCH] Add zha typing [core.discovery] (1) (#68359) * Add zha typing [core.discovery] (1) * Fix circular import --- .../components/zha/core/discovery.py | 77 +++++++++++-------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 26323793e13..9f7523d41f0 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections import Counter from collections.abc import Callable import logging +from typing import TYPE_CHECKING from homeassistant import const as ha_const from homeassistant.core import HomeAssistant, callback @@ -11,9 +12,11 @@ from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, ) +from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_registry import async_entries_for_device +from homeassistant.helpers.typing import ConfigType -from . import const as zha_const, registries as zha_regs, typing as zha_typing +from . import const as zha_const, registries as zha_regs from .. import ( # noqa: F401 pylint: disable=unused-import, alarm_control_panel, binary_sensor, @@ -32,16 +35,23 @@ from .. import ( # noqa: F401 pylint: disable=unused-import, ) from .channels import base +if TYPE_CHECKING: + from ..entity import ZhaEntity + from .channels import ChannelPool + from .device import ZHADevice + from .gateway import ZHAGateway + from .group import ZHAGroup + _LOGGER = logging.getLogger(__name__) @callback async def async_add_entities( - _async_add_entities: Callable, + _async_add_entities: AddEntitiesCallback, entities: list[ tuple[ - zha_typing.ZhaEntityType, - tuple[str, zha_typing.ZhaDeviceType, list[zha_typing.ChannelType]], + type[ZhaEntity], + tuple[str, ZHADevice, list[base.ZigbeeChannel]], ] ], update_before_add: bool = True, @@ -50,20 +60,20 @@ async def async_add_entities( if not entities: return to_add = [ent_cls.create_entity(*args) for ent_cls, args in entities] - to_add = [entity for entity in to_add if entity is not None] - _async_add_entities(to_add, update_before_add=update_before_add) + entities_to_add = [entity for entity in to_add if entity is not None] + _async_add_entities(entities_to_add, update_before_add=update_before_add) entities.clear() class ProbeEndpoint: """All discovered channels and entities of an endpoint.""" - def __init__(self): + def __init__(self) -> None: """Initialize instance.""" - self._device_configs = {} + self._device_configs: ConfigType = {} @callback - def discover_entities(self, channel_pool: zha_typing.ChannelPoolType) -> None: + def discover_entities(self, channel_pool: ChannelPool) -> None: """Process an endpoint on a zigpy device.""" self.discover_by_device_type(channel_pool) self.discover_multi_entities(channel_pool) @@ -71,12 +81,14 @@ class ProbeEndpoint: zha_regs.ZHA_ENTITIES.clean_up() @callback - def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None: + def discover_by_device_type(self, channel_pool: ChannelPool) -> None: """Process an endpoint on a zigpy device.""" unique_id = channel_pool.unique_id - component = self._device_configs.get(unique_id, {}).get(ha_const.CONF_TYPE) + component: str | None = self._device_configs.get(unique_id, {}).get( + ha_const.CONF_TYPE + ) if component is None: ep_profile_id = channel_pool.endpoint.profile_id ep_device_type = channel_pool.endpoint.device_type @@ -93,7 +105,7 @@ class ProbeEndpoint: channel_pool.async_new_entity(component, entity_class, unique_id, claimed) @callback - def discover_by_cluster_id(self, channel_pool: zha_typing.ChannelPoolType) -> None: + def discover_by_cluster_id(self, channel_pool: ChannelPool) -> None: """Process an endpoint on a zigpy device.""" items = zha_regs.SINGLE_INPUT_CLUSTER_DEVICE_CLASS.items() @@ -125,8 +137,8 @@ class ProbeEndpoint: @staticmethod def probe_single_cluster( component: str, - channel: zha_typing.ChannelType, - ep_channels: zha_typing.ChannelPoolType, + channel: base.ZigbeeChannel, + ep_channels: ChannelPool, ) -> None: """Probe specified cluster for specific component.""" if component is None or component not in zha_const.PLATFORMS: @@ -142,9 +154,7 @@ class ProbeEndpoint: ep_channels.claim_channels(claimed) ep_channels.async_new_entity(component, entity_class, unique_id, claimed) - def handle_on_off_output_cluster_exception( - self, ep_channels: zha_typing.ChannelPoolType - ) -> None: + def handle_on_off_output_cluster_exception(self, ep_channels: ChannelPool) -> None: """Process output clusters of the endpoint.""" profile_id = ep_channels.endpoint.profile_id @@ -167,7 +177,7 @@ class ProbeEndpoint: @staticmethod @callback - def discover_multi_entities(channel_pool: zha_typing.ChannelPoolType) -> None: + def discover_multi_entities(channel_pool: ChannelPool) -> None: """Process an endpoint on and discover multiple entities.""" ep_profile_id = channel_pool.endpoint.profile_id @@ -209,7 +219,9 @@ class ProbeEndpoint: def initialize(self, hass: HomeAssistant) -> None: """Update device overrides config.""" - zha_config = hass.data[zha_const.DATA_ZHA].get(zha_const.DATA_ZHA_CONFIG, {}) + zha_config: ConfigType = hass.data[zha_const.DATA_ZHA].get( + zha_const.DATA_ZHA_CONFIG, {} + ) if overrides := zha_config.get(zha_const.CONF_DEVICE_CONFIG): self._device_configs.update(overrides) @@ -217,10 +229,11 @@ class ProbeEndpoint: class GroupProbe: """Determine the appropriate component for a group.""" - def __init__(self): + _hass: HomeAssistant + + def __init__(self) -> None: """Initialize instance.""" - self._hass = None - self._unsubs = [] + self._unsubs: list[Callable[[], None]] = [] def initialize(self, hass: HomeAssistant) -> None: """Initialize the group probe.""" @@ -231,7 +244,7 @@ class GroupProbe: ) ) - def cleanup(self): + def cleanup(self) -> None: """Clean up on when zha shuts down.""" for unsub in self._unsubs[:]: unsub() @@ -240,13 +253,15 @@ class GroupProbe: @callback def _reprobe_group(self, group_id: int) -> None: """Reprobe a group for entities after its members change.""" - zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][ + zha_const.DATA_ZHA_GATEWAY + ] if (zha_group := zha_gateway.groups.get(group_id)) is None: return self.discover_group_entities(zha_group) @callback - def discover_group_entities(self, group: zha_typing.ZhaGroupType) -> None: + def discover_group_entities(self, group: ZHAGroup) -> None: """Process a group and create any entities that are needed.""" # only create a group entity if there are 2 or more members in a group if len(group.members) < 2: @@ -262,7 +277,9 @@ class GroupProbe: if not entity_domains: return - zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][ + zha_const.DATA_ZHA_GATEWAY + ] for domain in entity_domains: entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain) if entity_class is None: @@ -281,12 +298,12 @@ class GroupProbe: async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES) @staticmethod - def determine_entity_domains( - hass: HomeAssistant, group: zha_typing.ZhaGroupType - ) -> list[str]: + def determine_entity_domains(hass: HomeAssistant, group: ZHAGroup) -> list[str]: """Determine the entity domains for this group.""" entity_domains: list[str] = [] - zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = hass.data[zha_const.DATA_ZHA][ + zha_const.DATA_ZHA_GATEWAY + ] all_domain_occurrences = [] for member in group.members: if member.device.is_coordinator: