Add zha typing [core.gateway] (1) (#68397)

* Add zha typing [core.gateway] (1)

* Add temporary type ignores

* Fix pylint
This commit is contained in:
Marc Mueller 2022-03-22 15:13:09 +01:00 committed by GitHub
parent 94cd656670
commit bdc92271f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 53 deletions

View File

@ -109,8 +109,8 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
device_registry = await hass.helpers.device_registry.async_get_registry()
device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(CONNECTION_ZIGBEE, str(zha_gateway.application_controller.ieee))},
identifiers={(DOMAIN, str(zha_gateway.application_controller.ieee))},
connections={(CONNECTION_ZIGBEE, str(zha_gateway.application_controller.ieee))}, # type: ignore[attr-defined]
identifiers={(DOMAIN, str(zha_gateway.application_controller.ieee))}, # type: ignore[attr-defined]
name="Zigbee Coordinator",
manufacturer="ZHA",
model=zha_gateway.radio_description,

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import collections
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
import itertools
@ -10,13 +11,18 @@ import logging
import os
import time
import traceback
from typing import TYPE_CHECKING, Any, Union
from serial import SerialException
from zigpy.config import CONF_DEVICE
import zigpy.device as zigpy_dev
import zigpy.device
import zigpy.endpoint
import zigpy.group
from zigpy.types.named import EUI64
from homeassistant.components.system_log import LogEntry, _figure_out_source
from homeassistant.core import callback
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.device_registry import (
CONNECTION_ZIGBEE,
@ -28,8 +34,9 @@ from homeassistant.helpers.entity_registry import (
async_get_registry as get_ent_reg,
)
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType
from . import discovery, typing as zha_typing
from . import discovery
from .const import (
ATTR_IEEE,
ATTR_MANUFACTURER,
@ -81,7 +88,13 @@ from .device import DeviceStatus, ZHADevice
from .group import GroupMember, ZHAGroup
from .registries import GROUP_ENTITY_DOMAINS
from .store import async_get_registry
from .typing import ZhaGroupType, ZigpyEndpointType, ZigpyGroupType
if TYPE_CHECKING:
from logging import Filter, LogRecord
from ..entity import ZhaEntity
_LogFilterType = Union[Filter, Callable[[LogRecord], int]]
_LOGGER = logging.getLogger(__name__)
@ -103,29 +116,33 @@ class DevicePairingStatus(Enum):
class ZHAGateway:
"""Gateway that handles events that happen on the ZHA Zigbee network."""
def __init__(self, hass, config, config_entry):
def __init__(
self, hass: HomeAssistant, config: ConfigType, config_entry: ConfigEntry
) -> None:
"""Initialize the gateway."""
self._hass = hass
self._config = config
self._devices = {}
self._groups = {}
self.coordinator_zha_device = None
self._device_registry = collections.defaultdict(list)
self._devices: dict[EUI64, ZHADevice] = {}
self._groups: dict[int, ZHAGroup] = {}
self.coordinator_zha_device: ZHADevice | None = None
self._device_registry: collections.defaultdict[
EUI64, list[EntityReference]
] = collections.defaultdict(list)
self.zha_storage = None
self.ha_device_registry = None
self.ha_entity_registry = None
self.application_controller = None
self.radio_description = None
self._log_levels = {
self._log_levels: dict[str, dict[str, int]] = {
DEBUG_LEVEL_ORIGINAL: async_capture_log_levels(),
DEBUG_LEVEL_CURRENT: async_capture_log_levels(),
}
self.debug_enabled = False
self._log_relay_handler = LogRelayHandler(hass, self)
self.config_entry = config_entry
self._unsubs = []
self._unsubs: list[Callable[[], None]] = []
async def async_initialize(self):
async def async_initialize(self) -> None:
"""Initialize controller and connect radio."""
discovery.PROBE.initialize(self._hass)
discovery.GROUP_PROBE.initialize(self._hass)
@ -211,7 +228,7 @@ class ZHAGateway:
"""Initialize devices and load entities."""
semaphore = asyncio.Semaphore(2)
async def _throttle(zha_device: zha_typing.ZhaDeviceType, cached: bool):
async def _throttle(zha_device: ZHADevice, cached: bool) -> None:
async with semaphore:
await zha_device.async_initialize(from_cache=cached)
@ -233,7 +250,7 @@ class ZHAGateway:
)
)
def device_joined(self, device):
def device_joined(self, device: zigpy.device.Device) -> None:
"""Handle device joined.
At this point, no information about the device is known other than its
@ -252,7 +269,7 @@ class ZHAGateway:
},
)
def raw_device_initialized(self, device):
def raw_device_initialized(self, device: zigpy.device.Device) -> None:
"""Handle a device initialization without quirks loaded."""
manuf = device.manufacturer
async_dispatcher_send(
@ -271,16 +288,16 @@ class ZHAGateway:
},
)
def device_initialized(self, device):
def device_initialized(self, device: zigpy.device.Device) -> None:
"""Handle device joined and basic information discovered."""
self._hass.async_create_task(self.async_device_initialized(device))
def device_left(self, device: zigpy_dev.Device):
def device_left(self, device: zigpy.device.Device) -> None:
"""Handle device leaving the network."""
self.async_update_device(device, False)
def group_member_removed(
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType
self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint
) -> None:
"""Handle zigpy group member removed event."""
# need to handle endpoint correctly on groups
@ -292,7 +309,7 @@ class ZHAGateway:
)
def group_member_added(
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType
self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint
) -> None:
"""Handle zigpy group member added event."""
# need to handle endpoint correctly on groups
@ -306,14 +323,14 @@ class ZHAGateway:
# we need to do this because there wasn't already a group entity to remove and re-add
discovery.GROUP_PROBE.discover_group_entities(zha_group)
def group_added(self, zigpy_group: ZigpyGroupType) -> None:
def group_added(self, zigpy_group: zigpy.group.Group) -> None:
"""Handle zigpy group added event."""
zha_group = self._async_get_or_create_group(zigpy_group)
zha_group.info("group_added")
# need to dispatch for entity creation here
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED)
def group_removed(self, zigpy_group: ZigpyGroupType) -> None:
def group_removed(self, zigpy_group: zigpy.group.Group) -> None:
"""Handle zigpy group removed event."""
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED)
zha_group = self._groups.pop(zigpy_group.group_id, None)
@ -321,7 +338,7 @@ class ZHAGateway:
self._cleanup_group_entity_registry_entries(zigpy_group)
def _send_group_gateway_message(
self, zigpy_group: ZigpyGroupType, gateway_message_type: str
self, zigpy_group: zigpy.group.Group, gateway_message_type: str
) -> None:
"""Send the gateway event for a zigpy group event."""
zha_group = self._groups.get(zigpy_group.group_id)
@ -335,7 +352,9 @@ class ZHAGateway:
},
)
async def _async_remove_device(self, device, entity_refs):
async def _async_remove_device(
self, device: ZHADevice, entity_refs: list[EntityReference] | None
) -> None:
if entity_refs is not None:
remove_tasks = []
for entity_ref in entity_refs:
@ -346,7 +365,7 @@ class ZHAGateway:
if reg_device is not None:
self.ha_device_registry.async_remove_device(reg_device.id)
def device_removed(self, device):
def device_removed(self, device: zigpy.device.Device) -> None:
"""Handle device being removed from the network."""
zha_device = self._devices.pop(device.ieee, None)
entity_refs = self._device_registry.pop(device.ieee, None)
@ -365,23 +384,23 @@ class ZHAGateway:
},
)
def get_device(self, ieee):
def get_device(self, ieee: EUI64) -> ZHADevice | None:
"""Return ZHADevice for given ieee."""
return self._devices.get(ieee)
def get_group(self, group_id: str) -> ZhaGroupType | None:
def get_group(self, group_id: int) -> ZHAGroup | None:
"""Return Group for given group id."""
return self.groups.get(group_id)
@callback
def async_get_group_by_name(self, group_name: str) -> ZhaGroupType | None:
def async_get_group_by_name(self, group_name: str) -> ZHAGroup | None:
"""Get ZHA group by name."""
for group in self.groups.values():
if group.name == group_name:
return group
return None
def get_entity_reference(self, entity_id):
def get_entity_reference(self, entity_id: str) -> EntityReference | None:
"""Return entity reference for given entity_id if found."""
for entity_reference in itertools.chain.from_iterable(
self.device_registry.values()
@ -389,7 +408,7 @@ class ZHAGateway:
if entity_id == entity_reference.reference_id:
return entity_reference
def remove_entity_reference(self, entity):
def remove_entity_reference(self, entity: ZhaEntity) -> None:
"""Remove entity reference for given entity_id if found."""
if entity.zha_device.ieee in self.device_registry:
entity_refs = self.device_registry.get(entity.zha_device.ieee)
@ -398,7 +417,7 @@ class ZHAGateway:
]
def _cleanup_group_entity_registry_entries(
self, zigpy_group: ZigpyGroupType
self, zigpy_group: zigpy.group.Group
) -> None:
"""Remove entity registry entries for group entities when the groups are removed from HA."""
# first we collect the potential unique ids for entities that could be created from this group
@ -429,17 +448,17 @@ class ZHAGateway:
self.ha_entity_registry.async_remove(entry.entity_id)
@property
def devices(self):
def devices(self) -> dict[EUI64, ZHADevice]:
"""Return devices."""
return self._devices
@property
def groups(self):
def groups(self) -> dict[int, ZHAGroup]:
"""Return groups."""
return self._groups
@property
def device_registry(self):
def device_registry(self) -> collections.defaultdict[EUI64, list[EntityReference]]:
"""Return entities by ieee."""
return self._device_registry
@ -464,7 +483,7 @@ class ZHAGateway:
)
@callback
def async_enable_debug_mode(self, filterer=None):
def async_enable_debug_mode(self, filterer: _LogFilterType | None = None) -> None:
"""Enable debug mode for ZHA."""
self._log_levels[DEBUG_LEVEL_ORIGINAL] = async_capture_log_levels()
async_set_logger_levels(DEBUG_LEVELS)
@ -479,7 +498,7 @@ class ZHAGateway:
self.debug_enabled = True
@callback
def async_disable_debug_mode(self, filterer=None):
def async_disable_debug_mode(self, filterer: _LogFilterType | None = None) -> None:
"""Disable debug mode for ZHA."""
async_set_logger_levels(self._log_levels[DEBUG_LEVEL_ORIGINAL])
self._log_levels[DEBUG_LEVEL_CURRENT] = async_capture_log_levels()
@ -491,8 +510,8 @@ class ZHAGateway:
@callback
def _async_get_or_create_device(
self, zigpy_device: zha_typing.ZigpyDeviceType, restored: bool = False
):
self, zigpy_device: zigpy.device.Device, restored: bool = False
) -> ZHADevice:
"""Get or create a ZHA device."""
if (zha_device := self._devices.get(zigpy_device.ieee)) is None:
zha_device = ZHADevice.new(self._hass, zigpy_device, self, restored)
@ -511,7 +530,7 @@ class ZHAGateway:
return zha_device
@callback
def _async_get_or_create_group(self, zigpy_group: ZigpyGroupType) -> ZhaGroupType:
def _async_get_or_create_group(self, zigpy_group: zigpy.group.Group) -> ZHAGroup:
"""Get or create a ZHA group."""
zha_group = self._groups.get(zigpy_group.group_id)
if zha_group is None:
@ -521,7 +540,7 @@ class ZHAGateway:
@callback
def async_update_device(
self, sender: zigpy_dev.Device, available: bool = True
self, sender: zigpy.device.Device, available: bool = True
) -> None:
"""Update device that has just become available."""
if sender.ieee in self.devices:
@ -530,12 +549,12 @@ class ZHAGateway:
if device.status is DeviceStatus.INITIALIZED:
device.update_available(available)
async def async_update_device_storage(self, *_):
async def async_update_device_storage(self, *_: Any) -> None:
"""Update the devices in the store."""
for device in self.devices.values():
self.zha_storage.async_update_device(device)
async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType):
async def async_device_initialized(self, device: zigpy.device.Device) -> None:
"""Handle device joined and basic information discovered (async)."""
zha_device = self._async_get_or_create_device(device)
# This is an active device so set a last seen if it is none
@ -576,7 +595,7 @@ class ZHAGateway:
},
)
async def _async_device_joined(self, zha_device: zha_typing.ZhaDeviceType) -> None:
async def _async_device_joined(self, zha_device: ZHADevice) -> None:
zha_device.available = True
device_info = zha_device.device_info
await zha_device.async_configure()
@ -592,7 +611,7 @@ class ZHAGateway:
await zha_device.async_initialize(from_cache=False)
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
async def _async_device_rejoined(self, zha_device):
async def _async_device_rejoined(self, zha_device: ZHADevice) -> None:
_LOGGER.debug(
"skipping discovery for previously discovered device - %s:%s",
zha_device.nwk,
@ -615,8 +634,11 @@ class ZHAGateway:
zha_device.update_available(True)
async def async_create_zigpy_group(
self, name: str, members: list[GroupMember], group_id: int = None
) -> ZhaGroupType:
self,
name: str,
members: list[GroupMember] | None,
group_id: int | None = None,
) -> ZHAGroup | None:
"""Create a new Zigpy Zigbee group."""
# we start with two to fill any gaps from a user removing existing groups
@ -659,7 +681,7 @@ class ZHAGateway:
await asyncio.gather(*tasks)
self.application_controller.groups.pop(group_id)
async def shutdown(self):
async def shutdown(self) -> None:
"""Stop ZHA Controller Application."""
_LOGGER.debug("Shutting down ZHA ControllerApplication")
for unsubscribe in self._unsubs:
@ -668,7 +690,7 @@ class ZHAGateway:
def handle_message(
self,
sender: zigpy_dev.Device,
sender: zigpy.device.Device,
profile: int,
cluster: int,
src_ep: int,
@ -681,7 +703,7 @@ class ZHAGateway:
@callback
def async_capture_log_levels():
def async_capture_log_levels() -> dict[str, int]:
"""Capture current logger levels for ZHA."""
return {
DEBUG_COMP_BELLOWS: logging.getLogger(DEBUG_COMP_BELLOWS).getEffectiveLevel(),
@ -703,7 +725,7 @@ def async_capture_log_levels():
@callback
def async_set_logger_levels(levels):
def async_set_logger_levels(levels: dict[str, int]) -> None:
"""Set logger levels for ZHA."""
logging.getLogger(DEBUG_COMP_BELLOWS).setLevel(levels[DEBUG_COMP_BELLOWS])
logging.getLogger(DEBUG_COMP_ZHA).setLevel(levels[DEBUG_COMP_ZHA])
@ -717,13 +739,13 @@ def async_set_logger_levels(levels):
class LogRelayHandler(logging.Handler):
"""Log handler for error messages."""
def __init__(self, hass, gateway):
def __init__(self, hass: HomeAssistant, gateway: ZHAGateway) -> None:
"""Initialize a new LogErrorHandler."""
super().__init__()
self.hass = hass
self.gateway = gateway
def emit(self, record):
def emit(self, record: LogRecord) -> None:
"""Relay log message via dispatcher."""
stack = []
if record.levelno >= logging.WARN and not record.exc_info: