Add zha typing [core.gateway] (3) (#68685)

This commit is contained in:
Marc Mueller 2022-03-28 23:58:06 +02:00 committed by GitHub
parent f2aee38841
commit f0e2f964e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 23 deletions

View File

@ -11,7 +11,7 @@ import logging
import os
import time
import traceback
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, NamedTuple, Union
from serial import SerialException
from zigpy.application import ControllerApplication
@ -31,6 +31,7 @@ from homeassistant.helpers.device_registry import (
async_get_registry as get_dev_reg,
)
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_registry import (
EntityRegistry,
async_entries_for_device,
@ -96,16 +97,22 @@ if TYPE_CHECKING:
from logging import Filter, LogRecord
from ..entity import ZhaEntity
from .channels.base import ZigbeeChannel
from .store import ZhaStorage
_LogFilterType = Union[Filter, Callable[[LogRecord], int]]
_LOGGER = logging.getLogger(__name__)
EntityReference = collections.namedtuple(
"EntityReference",
"reference_id zha_device cluster_channels device_info remove_future",
)
class EntityReference(NamedTuple):
"""Describes an entity reference."""
reference_id: str
zha_device: ZHADevice
cluster_channels: dict[str, ZigbeeChannel]
device_info: DeviceInfo
remove_future: asyncio.Future[Any]
class DevicePairingStatus(Enum):
@ -362,7 +369,7 @@ class ZHAGateway:
self, device: ZHADevice, entity_refs: list[EntityReference] | None
) -> None:
if entity_refs is not None:
remove_tasks = []
remove_tasks: list[asyncio.Future[Any]] = []
for entity_ref in entity_refs:
remove_tasks.append(entity_ref.remove_future)
if remove_tasks:
@ -413,13 +420,14 @@ class ZHAGateway:
):
if entity_id == entity_reference.reference_id:
return entity_reference
return None
def remove_entity_reference(self, entity: ZhaEntity) -> None:
"""Remove entity reference for given entity_id if found."""
if entity.zha_device.ieee in self.device_registry:
entity_refs = self.device_registry.get(entity.zha_device.ieee)
self.device_registry[entity.zha_device.ieee] = [
e for e in entity_refs if e.reference_id != entity.entity_id
e for e in entity_refs if e.reference_id != entity.entity_id # type: ignore[union-attr]
]
def _cleanup_group_entity_registry_entries(
@ -470,12 +478,12 @@ class ZHAGateway:
def register_entity_reference(
self,
ieee,
reference_id,
zha_device,
cluster_channels,
device_info,
remove_future,
ieee: EUI64,
reference_id: str,
zha_device: ZHADevice,
cluster_channels: dict[str, ZigbeeChannel],
device_info: DeviceInfo,
remove_future: asyncio.Future[Any],
):
"""Record the creation of a hass entity associated with ieee."""
self._device_registry[ieee].append(

View File

@ -2,10 +2,9 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable
import functools
import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from homeassistant.const import ATTR_NAME
from homeassistant.core import CALLBACK_TYPE, Event, callback
@ -32,6 +31,10 @@ from .core.const import (
from .core.helpers import LogMixin
from .core.typing import CALLABLE_T, ChannelType, ZhaDeviceType
if TYPE_CHECKING:
from .core.channels.base import ZigbeeChannel
from .core.device import ZHADevice
_LOGGER = logging.getLogger(__name__)
ENTITY_SUFFIX = "entity_suffix"
@ -43,7 +46,7 @@ class BaseZhaEntity(LogMixin, entity.Entity):
unique_id_suffix: str | None = None
def __init__(self, unique_id: str, zha_device: ZhaDeviceType, **kwargs) -> None:
def __init__(self, unique_id: str, zha_device: ZHADevice, **kwargs: Any) -> None:
"""Init ZHA entity."""
self._name: str = ""
self._force_update: bool = False
@ -53,9 +56,9 @@ class BaseZhaEntity(LogMixin, entity.Entity):
self._unique_id += f"-{self.unique_id_suffix}"
self._state: Any = None
self._extra_state_attributes: dict[str, Any] = {}
self._zha_device: ZhaDeviceType = zha_device
self._zha_device = zha_device
self._unsubs: list[CALLABLE_T] = []
self.remove_future: Awaitable[None] = None
self.remove_future: asyncio.Future[Any] = asyncio.Future()
@property
def name(self) -> str:
@ -68,7 +71,7 @@ class BaseZhaEntity(LogMixin, entity.Entity):
return self._unique_id
@property
def zha_device(self) -> ZhaDeviceType:
def zha_device(self) -> ZHADevice:
"""Return the zha device this entity is attached to."""
return self._zha_device
@ -159,9 +162,9 @@ class ZhaEntity(BaseZhaEntity, RestoreEntity):
def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
zha_device: ZHADevice,
channels: list[ZigbeeChannel],
**kwargs: Any,
) -> None:
"""Init ZHA entity."""
super().__init__(unique_id, zha_device, **kwargs)
@ -170,7 +173,7 @@ class ZhaEntity(BaseZhaEntity, RestoreEntity):
self._name: str = f"{zha_device.name} {ieeetail} {ch_names}"
if self.unique_id_suffix:
self._name += f" {self.unique_id_suffix}"
self.cluster_channels: dict[str, ChannelType] = {}
self.cluster_channels: dict[str, ZigbeeChannel] = {}
for channel in channels:
self.cluster_channels[channel.name] = channel