Add zha typing [core.gateway] (3) (#68685)

This commit is contained in:
Marc Mueller 2022-03-28 23:58:06 +02:00 committed by GitHub
parent f2aee38841
commit f0e2f964e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 23 deletions

View File

@ -11,7 +11,7 @@ import logging
import os import os
import time import time
import traceback import traceback
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any, NamedTuple, Union
from serial import SerialException from serial import SerialException
from zigpy.application import ControllerApplication from zigpy.application import ControllerApplication
@ -31,6 +31,7 @@ from homeassistant.helpers.device_registry import (
async_get_registry as get_dev_reg, async_get_registry as get_dev_reg,
) )
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_registry import ( from homeassistant.helpers.entity_registry import (
EntityRegistry, EntityRegistry,
async_entries_for_device, async_entries_for_device,
@ -96,16 +97,22 @@ if TYPE_CHECKING:
from logging import Filter, LogRecord from logging import Filter, LogRecord
from ..entity import ZhaEntity from ..entity import ZhaEntity
from .channels.base import ZigbeeChannel
from .store import ZhaStorage from .store import ZhaStorage
_LogFilterType = Union[Filter, Callable[[LogRecord], int]] _LogFilterType = Union[Filter, Callable[[LogRecord], int]]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
EntityReference = collections.namedtuple(
"EntityReference", class EntityReference(NamedTuple):
"reference_id zha_device cluster_channels device_info remove_future", """Describes an entity reference."""
)
reference_id: str
zha_device: ZHADevice
cluster_channels: dict[str, ZigbeeChannel]
device_info: DeviceInfo
remove_future: asyncio.Future[Any]
class DevicePairingStatus(Enum): class DevicePairingStatus(Enum):
@ -362,7 +369,7 @@ class ZHAGateway:
self, device: ZHADevice, entity_refs: list[EntityReference] | None self, device: ZHADevice, entity_refs: list[EntityReference] | None
) -> None: ) -> None:
if entity_refs is not None: if entity_refs is not None:
remove_tasks = [] remove_tasks: list[asyncio.Future[Any]] = []
for entity_ref in entity_refs: for entity_ref in entity_refs:
remove_tasks.append(entity_ref.remove_future) remove_tasks.append(entity_ref.remove_future)
if remove_tasks: if remove_tasks:
@ -413,13 +420,14 @@ class ZHAGateway:
): ):
if entity_id == entity_reference.reference_id: if entity_id == entity_reference.reference_id:
return entity_reference return entity_reference
return None
def remove_entity_reference(self, entity: ZhaEntity) -> None: def remove_entity_reference(self, entity: ZhaEntity) -> None:
"""Remove entity reference for given entity_id if found.""" """Remove entity reference for given entity_id if found."""
if entity.zha_device.ieee in self.device_registry: if entity.zha_device.ieee in self.device_registry:
entity_refs = self.device_registry.get(entity.zha_device.ieee) entity_refs = self.device_registry.get(entity.zha_device.ieee)
self.device_registry[entity.zha_device.ieee] = [ self.device_registry[entity.zha_device.ieee] = [
e for e in entity_refs if e.reference_id != entity.entity_id e for e in entity_refs if e.reference_id != entity.entity_id # type: ignore[union-attr]
] ]
def _cleanup_group_entity_registry_entries( def _cleanup_group_entity_registry_entries(
@ -470,12 +478,12 @@ class ZHAGateway:
def register_entity_reference( def register_entity_reference(
self, self,
ieee, ieee: EUI64,
reference_id, reference_id: str,
zha_device, zha_device: ZHADevice,
cluster_channels, cluster_channels: dict[str, ZigbeeChannel],
device_info, device_info: DeviceInfo,
remove_future, remove_future: asyncio.Future[Any],
): ):
"""Record the creation of a hass entity associated with ieee.""" """Record the creation of a hass entity associated with ieee."""
self._device_registry[ieee].append( self._device_registry[ieee].append(

View File

@ -2,10 +2,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable
import functools import functools
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
from homeassistant.const import ATTR_NAME from homeassistant.const import ATTR_NAME
from homeassistant.core import CALLBACK_TYPE, Event, callback from homeassistant.core import CALLBACK_TYPE, Event, callback
@ -32,6 +31,10 @@ from .core.const import (
from .core.helpers import LogMixin from .core.helpers import LogMixin
from .core.typing import CALLABLE_T, ChannelType, ZhaDeviceType from .core.typing import CALLABLE_T, ChannelType, ZhaDeviceType
if TYPE_CHECKING:
from .core.channels.base import ZigbeeChannel
from .core.device import ZHADevice
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ENTITY_SUFFIX = "entity_suffix" ENTITY_SUFFIX = "entity_suffix"
@ -43,7 +46,7 @@ class BaseZhaEntity(LogMixin, entity.Entity):
unique_id_suffix: str | None = None unique_id_suffix: str | None = None
def __init__(self, unique_id: str, zha_device: ZhaDeviceType, **kwargs) -> None: def __init__(self, unique_id: str, zha_device: ZHADevice, **kwargs: Any) -> None:
"""Init ZHA entity.""" """Init ZHA entity."""
self._name: str = "" self._name: str = ""
self._force_update: bool = False self._force_update: bool = False
@ -53,9 +56,9 @@ class BaseZhaEntity(LogMixin, entity.Entity):
self._unique_id += f"-{self.unique_id_suffix}" self._unique_id += f"-{self.unique_id_suffix}"
self._state: Any = None self._state: Any = None
self._extra_state_attributes: dict[str, Any] = {} self._extra_state_attributes: dict[str, Any] = {}
self._zha_device: ZhaDeviceType = zha_device self._zha_device = zha_device
self._unsubs: list[CALLABLE_T] = [] self._unsubs: list[CALLABLE_T] = []
self.remove_future: Awaitable[None] = None self.remove_future: asyncio.Future[Any] = asyncio.Future()
@property @property
def name(self) -> str: def name(self) -> str:
@ -68,7 +71,7 @@ class BaseZhaEntity(LogMixin, entity.Entity):
return self._unique_id return self._unique_id
@property @property
def zha_device(self) -> ZhaDeviceType: def zha_device(self) -> ZHADevice:
"""Return the zha device this entity is attached to.""" """Return the zha device this entity is attached to."""
return self._zha_device return self._zha_device
@ -159,9 +162,9 @@ class ZhaEntity(BaseZhaEntity, RestoreEntity):
def __init__( def __init__(
self, self,
unique_id: str, unique_id: str,
zha_device: ZhaDeviceType, zha_device: ZHADevice,
channels: list[ChannelType], channels: list[ZigbeeChannel],
**kwargs, **kwargs: Any,
) -> None: ) -> None:
"""Init ZHA entity.""" """Init ZHA entity."""
super().__init__(unique_id, zha_device, **kwargs) super().__init__(unique_id, zha_device, **kwargs)
@ -170,7 +173,7 @@ class ZhaEntity(BaseZhaEntity, RestoreEntity):
self._name: str = f"{zha_device.name} {ieeetail} {ch_names}" self._name: str = f"{zha_device.name} {ieeetail} {ch_names}"
if self.unique_id_suffix: if self.unique_id_suffix:
self._name += f" {self.unique_id_suffix}" self._name += f" {self.unique_id_suffix}"
self.cluster_channels: dict[str, ChannelType] = {} self.cluster_channels: dict[str, ZigbeeChannel] = {}
for channel in channels: for channel in channels:
self.cluster_channels[channel.name] = channel self.cluster_channels[channel.name] = channel