diff --git a/homeassistant/components/zha/core/decorators.py b/homeassistant/components/zha/core/decorators.py index a27e4cc0bfc..c57cad7d65e 100644 --- a/homeassistant/components/zha/core/decorators.py +++ b/homeassistant/components/zha/core/decorators.py @@ -2,37 +2,32 @@ from __future__ import annotations from collections.abc import Callable -from typing import TypeVar +from typing import Any, TypeVar, Union -CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name +_TypeT = TypeVar("_TypeT", bound=type[Any]) -class DictRegistry(dict): +class DictRegistry(dict[Union[int, str], _TypeT]): """Dict Registry of items.""" - def register( - self, name: int | str, item: str | CALLABLE_T = None - ) -> Callable[[CALLABLE_T], CALLABLE_T]: + def register(self, name: int | str) -> Callable[[_TypeT], _TypeT]: """Return decorator to register item with a specific name.""" - def decorator(channel: CALLABLE_T) -> CALLABLE_T: + def decorator(channel: _TypeT) -> _TypeT: """Register decorated channel or item.""" - if item is None: - self[name] = channel - else: - self[name] = item + self[name] = channel return channel return decorator -class SetRegistry(set): +class SetRegistry(set[Union[int, str]]): """Set Registry of items.""" - def register(self, name: int | str) -> Callable[[CALLABLE_T], CALLABLE_T]: + def register(self, name: int | str) -> Callable[[_TypeT], _TypeT]: """Return decorator to register item with a specific name.""" - def decorator(channel: CALLABLE_T) -> CALLABLE_T: + def decorator(channel: _TypeT) -> _TypeT: """Register decorated channel or item.""" self.add(name) return channel diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 1480469ce2c..1d3482cd8f4 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -4,6 +4,7 @@ from __future__ import annotations import collections from collections.abc import Callable import dataclasses +from typing import TYPE_CHECKING import attr from zigpy import zcl @@ -15,8 +16,11 @@ from homeassistant.const import Platform # importing channels updates registries from . import channels as zha_channels # noqa: F401 pylint: disable=unused-import -from .decorators import CALLABLE_T, DictRegistry, SetRegistry -from .typing import ChannelType +from .decorators import DictRegistry, SetRegistry +from .typing import CALLABLE_T, ChannelType + +if TYPE_CHECKING: + from .channels.base import ClientChannel, ZigbeeChannel GROUP_ENTITY_DOMAINS = [Platform.LIGHT, Platform.SWITCH, Platform.FAN] @@ -98,8 +102,8 @@ DEVICE_CLASS = { } DEVICE_CLASS = collections.defaultdict(dict, DEVICE_CLASS) -CLIENT_CHANNELS_REGISTRY = DictRegistry() -ZIGBEE_CHANNEL_REGISTRY = DictRegistry() +CLIENT_CHANNELS_REGISTRY: DictRegistry[type[ClientChannel]] = DictRegistry() +ZIGBEE_CHANNEL_REGISTRY: DictRegistry[type[ZigbeeChannel]] = DictRegistry() def set_or_callable(value):