From b6d3e34ebc2a7348584052073ae40025a4bd7bf9 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Mon, 20 Jun 2022 14:50:27 +0200 Subject: [PATCH] Drop custom type (CALLABLE_T) from zha (#73736) * Drop CALLABLE_T from zha * Adjust .coveragerc * Apply suggestions from code review Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> * Add TypeVar * Apply suggestions from code review Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> * One more Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> * Flake8 Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> --- .coveragerc | 1 - .../components/zha/core/channels/security.py | 8 ++-- .../components/zha/core/registries.py | 42 +++++++++++-------- homeassistant/components/zha/core/typing.py | 6 --- homeassistant/components/zha/entity.py | 6 +-- 5 files changed, 31 insertions(+), 32 deletions(-) delete mode 100644 homeassistant/components/zha/core/typing.py diff --git a/.coveragerc b/.coveragerc index 46cbed3a0dc..eba89e5f238 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1516,7 +1516,6 @@ omit = homeassistant/components/zha/core/gateway.py homeassistant/components/zha/core/helpers.py homeassistant/components/zha/core/registries.py - homeassistant/components/zha/core/typing.py homeassistant/components/zha/entity.py homeassistant/components/zha/light.py homeassistant/components/zha/sensor.py diff --git a/homeassistant/components/zha/core/channels/security.py b/homeassistant/components/zha/core/channels/security.py index 4c0d6bbfd59..789e792e149 100644 --- a/homeassistant/components/zha/core/channels/security.py +++ b/homeassistant/components/zha/core/channels/security.py @@ -7,7 +7,8 @@ https://home-assistant.io/integrations/zha/ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from zigpy.exceptions import ZigbeeException import zigpy.zcl @@ -25,7 +26,6 @@ from ..const import ( WARNING_DEVICE_STROBE_HIGH, WARNING_DEVICE_STROBE_YES, ) -from ..typing import CALLABLE_T from .base import ChannelStatus, ZigbeeChannel if TYPE_CHECKING: @@ -55,7 +55,7 @@ class IasAce(ZigbeeChannel): def __init__(self, cluster: zigpy.zcl.Cluster, ch_pool: ChannelPool) -> None: """Initialize IAS Ancillary Control Equipment channel.""" super().__init__(cluster, ch_pool) - self.command_map: dict[int, CALLABLE_T] = { + self.command_map: dict[int, Callable[..., Any]] = { IAS_ACE_ARM: self.arm, IAS_ACE_BYPASS: self._bypass, IAS_ACE_EMERGENCY: self._emergency, @@ -67,7 +67,7 @@ class IasAce(ZigbeeChannel): IAS_ACE_GET_BYPASSED_ZONE_LIST: self._get_bypassed_zone_list, IAS_ACE_GET_ZONE_STATUS: self._get_zone_status, } - self.arm_map: dict[AceCluster.ArmMode, CALLABLE_T] = { + self.arm_map: dict[AceCluster.ArmMode, Callable[..., Any]] = { AceCluster.ArmMode.Disarm: self._disarm, AceCluster.ArmMode.Arm_All_Zones: self._arm_away, AceCluster.ArmMode.Arm_Day_Home_Only: self._arm_day, diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index ed6b047566c..7e2114b5911 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -4,7 +4,7 @@ from __future__ import annotations import collections from collections.abc import Callable import dataclasses -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar import attr from zigpy import zcl @@ -17,11 +17,15 @@ from homeassistant.const import Platform # importing channels updates registries from . import channels as zha_channels # noqa: F401 pylint: disable=unused-import from .decorators import DictRegistry, SetRegistry -from .typing import CALLABLE_T if TYPE_CHECKING: + from ..entity import ZhaEntity, ZhaGroupEntity from .channels.base import ClientChannel, ZigbeeChannel + +_ZhaEntityT = TypeVar("_ZhaEntityT", bound=type["ZhaEntity"]) +_ZhaGroupEntityT = TypeVar("_ZhaGroupEntityT", bound=type["ZhaGroupEntity"]) + GROUP_ENTITY_DOMAINS = [Platform.LIGHT, Platform.SWITCH, Platform.FAN] PHILLIPS_REMOTE_CLUSTER = 0xFC00 @@ -215,7 +219,7 @@ class MatchRule: class EntityClassAndChannels: """Container for entity class and corresponding channels.""" - entity_class: CALLABLE_T + entity_class: type[ZhaEntity] claimed_channel: list[ZigbeeChannel] @@ -225,19 +229,19 @@ class ZHAEntityRegistry: def __init__(self): """Initialize Registry instance.""" self._strict_registry: dict[ - str, dict[MatchRule, CALLABLE_T] + str, dict[MatchRule, type[ZhaEntity]] ] = collections.defaultdict(dict) self._multi_entity_registry: dict[ - str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]] + str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] ] = collections.defaultdict( lambda: collections.defaultdict(lambda: collections.defaultdict(list)) ) self._config_diagnostic_entity_registry: dict[ - str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]] + str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] ] = collections.defaultdict( lambda: collections.defaultdict(lambda: collections.defaultdict(list)) ) - self._group_registry: dict[str, CALLABLE_T] = {} + self._group_registry: dict[str, type[ZhaGroupEntity]] = {} self.single_device_matches: dict[ Platform, dict[EUI64, list[str]] ] = collections.defaultdict(lambda: collections.defaultdict(list)) @@ -248,8 +252,8 @@ class ZHAEntityRegistry: manufacturer: str, model: str, channels: list[ZigbeeChannel], - default: CALLABLE_T = None, - ) -> tuple[CALLABLE_T, list[ZigbeeChannel]]: + default: type[ZhaEntity] | None = None, + ) -> tuple[type[ZhaEntity] | None, list[ZigbeeChannel]]: """Match a ZHA Channels to a ZHA Entity class.""" matches = self._strict_registry[component] for match in sorted(matches, key=lambda x: x.weight, reverse=True): @@ -310,7 +314,7 @@ class ZHAEntityRegistry: return result, list(all_claimed) - def get_group_entity(self, component: str) -> CALLABLE_T: + def get_group_entity(self, component: str) -> type[ZhaGroupEntity] | None: """Match a ZHA group to a ZHA Entity class.""" return self._group_registry.get(component) @@ -322,14 +326,14 @@ class ZHAEntityRegistry: manufacturers: Callable | set[str] | str = None, models: Callable | set[str] | str = None, aux_channels: Callable | set[str] | str = None, - ) -> Callable[[CALLABLE_T], CALLABLE_T]: + ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a strict match rule.""" rule = MatchRule( channel_names, generic_ids, manufacturers, models, aux_channels ) - def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T: + def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT: """Register a strict match rule. All non empty fields of a match rule must match. @@ -348,7 +352,7 @@ class ZHAEntityRegistry: models: Callable | set[str] | str = None, aux_channels: Callable | set[str] | str = None, stop_on_match_group: int | str | None = None, - ) -> Callable[[CALLABLE_T], CALLABLE_T]: + ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a loose match rule.""" rule = MatchRule( @@ -359,7 +363,7 @@ class ZHAEntityRegistry: aux_channels, ) - def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T: + def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT: """Register a loose match rule. All non empty fields of a match rule must match. @@ -381,7 +385,7 @@ class ZHAEntityRegistry: models: Callable | set[str] | str = None, aux_channels: Callable | set[str] | str = None, stop_on_match_group: int | str | None = None, - ) -> Callable[[CALLABLE_T], CALLABLE_T]: + ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a loose match rule.""" rule = MatchRule( @@ -392,7 +396,7 @@ class ZHAEntityRegistry: aux_channels, ) - def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T: + def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT: """Register a loose match rule. All non empty fields of a match rule must match. @@ -405,10 +409,12 @@ class ZHAEntityRegistry: return decorator - def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]: + def group_match( + self, component: str + ) -> Callable[[_ZhaGroupEntityT], _ZhaGroupEntityT]: """Decorate a group match rule.""" - def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T: + def decorator(zha_ent: _ZhaGroupEntityT) -> _ZhaGroupEntityT: """Register a group match rule.""" self._group_registry[component] = zha_ent return zha_ent diff --git a/homeassistant/components/zha/core/typing.py b/homeassistant/components/zha/core/typing.py deleted file mode 100644 index 714dc03ef82..00000000000 --- a/homeassistant/components/zha/core/typing.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Typing helpers for ZHA component.""" -from collections.abc import Callable -from typing import TypeVar - -# pylint: disable=invalid-name -CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) diff --git a/homeassistant/components/zha/entity.py b/homeassistant/components/zha/entity.py index 88dc9454f37..fb1a35ff72b 100644 --- a/homeassistant/components/zha/entity.py +++ b/homeassistant/components/zha/entity.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable import functools import logging from typing import TYPE_CHECKING, Any @@ -29,7 +30,6 @@ from .core.const import ( SIGNAL_REMOVE, ) from .core.helpers import LogMixin -from .core.typing import CALLABLE_T if TYPE_CHECKING: from .core.channels.base import ZigbeeChannel @@ -57,7 +57,7 @@ class BaseZhaEntity(LogMixin, entity.Entity): self._state: Any = None self._extra_state_attributes: dict[str, Any] = {} self._zha_device = zha_device - self._unsubs: list[CALLABLE_T] = [] + self._unsubs: list[Callable[[], None]] = [] self.remove_future: asyncio.Future[Any] = asyncio.Future() @property @@ -130,7 +130,7 @@ class BaseZhaEntity(LogMixin, entity.Entity): self, channel: ZigbeeChannel, signal: str, - func: CALLABLE_T, + func: Callable[[], Any], signal_override=False, ): """Accept a signal from a channel."""