Add zha typing [core.gateway] (1) (#68397)

* Add zha typing [core.gateway] (1)

* Add temporary type ignores

* Fix pylint
This commit is contained in:
Marc Mueller 2022-03-22 15:13:09 +01:00 committed by GitHub
parent 94cd656670
commit bdc92271f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 53 deletions

View File

@ -109,8 +109,8 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
device_registry = await hass.helpers.device_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry()
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id, config_entry_id=config_entry.entry_id,
connections={(CONNECTION_ZIGBEE, str(zha_gateway.application_controller.ieee))}, connections={(CONNECTION_ZIGBEE, str(zha_gateway.application_controller.ieee))}, # type: ignore[attr-defined]
identifiers={(DOMAIN, str(zha_gateway.application_controller.ieee))}, identifiers={(DOMAIN, str(zha_gateway.application_controller.ieee))}, # type: ignore[attr-defined]
name="Zigbee Coordinator", name="Zigbee Coordinator",
manufacturer="ZHA", manufacturer="ZHA",
model=zha_gateway.radio_description, model=zha_gateway.radio_description,

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import collections import collections
from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
import itertools import itertools
@ -10,13 +11,18 @@ import logging
import os import os
import time import time
import traceback import traceback
from typing import TYPE_CHECKING, Any, Union
from serial import SerialException from serial import SerialException
from zigpy.config import CONF_DEVICE from zigpy.config import CONF_DEVICE
import zigpy.device as zigpy_dev import zigpy.device
import zigpy.endpoint
import zigpy.group
from zigpy.types.named import EUI64
from homeassistant.components.system_log import LogEntry, _figure_out_source from homeassistant.components.system_log import LogEntry, _figure_out_source
from homeassistant.core import callback from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.device_registry import ( from homeassistant.helpers.device_registry import (
CONNECTION_ZIGBEE, CONNECTION_ZIGBEE,
@ -28,8 +34,9 @@ from homeassistant.helpers.entity_registry import (
async_get_registry as get_ent_reg, async_get_registry as get_ent_reg,
) )
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType
from . import discovery, typing as zha_typing from . import discovery
from .const import ( from .const import (
ATTR_IEEE, ATTR_IEEE,
ATTR_MANUFACTURER, ATTR_MANUFACTURER,
@ -81,7 +88,13 @@ from .device import DeviceStatus, ZHADevice
from .group import GroupMember, ZHAGroup from .group import GroupMember, ZHAGroup
from .registries import GROUP_ENTITY_DOMAINS from .registries import GROUP_ENTITY_DOMAINS
from .store import async_get_registry from .store import async_get_registry
from .typing import ZhaGroupType, ZigpyEndpointType, ZigpyGroupType
if TYPE_CHECKING:
from logging import Filter, LogRecord
from ..entity import ZhaEntity
_LogFilterType = Union[Filter, Callable[[LogRecord], int]]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -103,29 +116,33 @@ class DevicePairingStatus(Enum):
class ZHAGateway: class ZHAGateway:
"""Gateway that handles events that happen on the ZHA Zigbee network.""" """Gateway that handles events that happen on the ZHA Zigbee network."""
def __init__(self, hass, config, config_entry): def __init__(
self, hass: HomeAssistant, config: ConfigType, config_entry: ConfigEntry
) -> None:
"""Initialize the gateway.""" """Initialize the gateway."""
self._hass = hass self._hass = hass
self._config = config self._config = config
self._devices = {} self._devices: dict[EUI64, ZHADevice] = {}
self._groups = {} self._groups: dict[int, ZHAGroup] = {}
self.coordinator_zha_device = None self.coordinator_zha_device: ZHADevice | None = None
self._device_registry = collections.defaultdict(list) self._device_registry: collections.defaultdict[
EUI64, list[EntityReference]
] = collections.defaultdict(list)
self.zha_storage = None self.zha_storage = None
self.ha_device_registry = None self.ha_device_registry = None
self.ha_entity_registry = None self.ha_entity_registry = None
self.application_controller = None self.application_controller = None
self.radio_description = None self.radio_description = None
self._log_levels = { self._log_levels: dict[str, dict[str, int]] = {
DEBUG_LEVEL_ORIGINAL: async_capture_log_levels(), DEBUG_LEVEL_ORIGINAL: async_capture_log_levels(),
DEBUG_LEVEL_CURRENT: async_capture_log_levels(), DEBUG_LEVEL_CURRENT: async_capture_log_levels(),
} }
self.debug_enabled = False self.debug_enabled = False
self._log_relay_handler = LogRelayHandler(hass, self) self._log_relay_handler = LogRelayHandler(hass, self)
self.config_entry = config_entry self.config_entry = config_entry
self._unsubs = [] self._unsubs: list[Callable[[], None]] = []
async def async_initialize(self): async def async_initialize(self) -> None:
"""Initialize controller and connect radio.""" """Initialize controller and connect radio."""
discovery.PROBE.initialize(self._hass) discovery.PROBE.initialize(self._hass)
discovery.GROUP_PROBE.initialize(self._hass) discovery.GROUP_PROBE.initialize(self._hass)
@ -211,7 +228,7 @@ class ZHAGateway:
"""Initialize devices and load entities.""" """Initialize devices and load entities."""
semaphore = asyncio.Semaphore(2) semaphore = asyncio.Semaphore(2)
async def _throttle(zha_device: zha_typing.ZhaDeviceType, cached: bool): async def _throttle(zha_device: ZHADevice, cached: bool) -> None:
async with semaphore: async with semaphore:
await zha_device.async_initialize(from_cache=cached) await zha_device.async_initialize(from_cache=cached)
@ -233,7 +250,7 @@ class ZHAGateway:
) )
) )
def device_joined(self, device): def device_joined(self, device: zigpy.device.Device) -> None:
"""Handle device joined. """Handle device joined.
At this point, no information about the device is known other than its At this point, no information about the device is known other than its
@ -252,7 +269,7 @@ class ZHAGateway:
}, },
) )
def raw_device_initialized(self, device): def raw_device_initialized(self, device: zigpy.device.Device) -> None:
"""Handle a device initialization without quirks loaded.""" """Handle a device initialization without quirks loaded."""
manuf = device.manufacturer manuf = device.manufacturer
async_dispatcher_send( async_dispatcher_send(
@ -271,16 +288,16 @@ class ZHAGateway:
}, },
) )
def device_initialized(self, device): def device_initialized(self, device: zigpy.device.Device) -> None:
"""Handle device joined and basic information discovered.""" """Handle device joined and basic information discovered."""
self._hass.async_create_task(self.async_device_initialized(device)) self._hass.async_create_task(self.async_device_initialized(device))
def device_left(self, device: zigpy_dev.Device): def device_left(self, device: zigpy.device.Device) -> None:
"""Handle device leaving the network.""" """Handle device leaving the network."""
self.async_update_device(device, False) self.async_update_device(device, False)
def group_member_removed( def group_member_removed(
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint
) -> None: ) -> None:
"""Handle zigpy group member removed event.""" """Handle zigpy group member removed event."""
# need to handle endpoint correctly on groups # need to handle endpoint correctly on groups
@ -292,7 +309,7 @@ class ZHAGateway:
) )
def group_member_added( def group_member_added(
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint
) -> None: ) -> None:
"""Handle zigpy group member added event.""" """Handle zigpy group member added event."""
# need to handle endpoint correctly on groups # need to handle endpoint correctly on groups
@ -306,14 +323,14 @@ class ZHAGateway:
# we need to do this because there wasn't already a group entity to remove and re-add # we need to do this because there wasn't already a group entity to remove and re-add
discovery.GROUP_PROBE.discover_group_entities(zha_group) discovery.GROUP_PROBE.discover_group_entities(zha_group)
def group_added(self, zigpy_group: ZigpyGroupType) -> None: def group_added(self, zigpy_group: zigpy.group.Group) -> None:
"""Handle zigpy group added event.""" """Handle zigpy group added event."""
zha_group = self._async_get_or_create_group(zigpy_group) zha_group = self._async_get_or_create_group(zigpy_group)
zha_group.info("group_added") zha_group.info("group_added")
# need to dispatch for entity creation here # need to dispatch for entity creation here
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED)
def group_removed(self, zigpy_group: ZigpyGroupType) -> None: def group_removed(self, zigpy_group: zigpy.group.Group) -> None:
"""Handle zigpy group removed event.""" """Handle zigpy group removed event."""
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED)
zha_group = self._groups.pop(zigpy_group.group_id, None) zha_group = self._groups.pop(zigpy_group.group_id, None)
@ -321,7 +338,7 @@ class ZHAGateway:
self._cleanup_group_entity_registry_entries(zigpy_group) self._cleanup_group_entity_registry_entries(zigpy_group)
def _send_group_gateway_message( def _send_group_gateway_message(
self, zigpy_group: ZigpyGroupType, gateway_message_type: str self, zigpy_group: zigpy.group.Group, gateway_message_type: str
) -> None: ) -> None:
"""Send the gateway event for a zigpy group event.""" """Send the gateway event for a zigpy group event."""
zha_group = self._groups.get(zigpy_group.group_id) zha_group = self._groups.get(zigpy_group.group_id)
@ -335,7 +352,9 @@ class ZHAGateway:
}, },
) )
async def _async_remove_device(self, device, entity_refs): async def _async_remove_device(
self, device: ZHADevice, entity_refs: list[EntityReference] | None
) -> None:
if entity_refs is not None: if entity_refs is not None:
remove_tasks = [] remove_tasks = []
for entity_ref in entity_refs: for entity_ref in entity_refs:
@ -346,7 +365,7 @@ class ZHAGateway:
if reg_device is not None: if reg_device is not None:
self.ha_device_registry.async_remove_device(reg_device.id) self.ha_device_registry.async_remove_device(reg_device.id)
def device_removed(self, device): def device_removed(self, device: zigpy.device.Device) -> None:
"""Handle device being removed from the network.""" """Handle device being removed from the network."""
zha_device = self._devices.pop(device.ieee, None) zha_device = self._devices.pop(device.ieee, None)
entity_refs = self._device_registry.pop(device.ieee, None) entity_refs = self._device_registry.pop(device.ieee, None)
@ -365,23 +384,23 @@ class ZHAGateway:
}, },
) )
def get_device(self, ieee): def get_device(self, ieee: EUI64) -> ZHADevice | None:
"""Return ZHADevice for given ieee.""" """Return ZHADevice for given ieee."""
return self._devices.get(ieee) return self._devices.get(ieee)
def get_group(self, group_id: str) -> ZhaGroupType | None: def get_group(self, group_id: int) -> ZHAGroup | None:
"""Return Group for given group id.""" """Return Group for given group id."""
return self.groups.get(group_id) return self.groups.get(group_id)
@callback @callback
def async_get_group_by_name(self, group_name: str) -> ZhaGroupType | None: def async_get_group_by_name(self, group_name: str) -> ZHAGroup | None:
"""Get ZHA group by name.""" """Get ZHA group by name."""
for group in self.groups.values(): for group in self.groups.values():
if group.name == group_name: if group.name == group_name:
return group return group
return None return None
def get_entity_reference(self, entity_id): def get_entity_reference(self, entity_id: str) -> EntityReference | None:
"""Return entity reference for given entity_id if found.""" """Return entity reference for given entity_id if found."""
for entity_reference in itertools.chain.from_iterable( for entity_reference in itertools.chain.from_iterable(
self.device_registry.values() self.device_registry.values()
@ -389,7 +408,7 @@ class ZHAGateway:
if entity_id == entity_reference.reference_id: if entity_id == entity_reference.reference_id:
return entity_reference return entity_reference
def remove_entity_reference(self, entity): def remove_entity_reference(self, entity: ZhaEntity) -> None:
"""Remove entity reference for given entity_id if found.""" """Remove entity reference for given entity_id if found."""
if entity.zha_device.ieee in self.device_registry: if entity.zha_device.ieee in self.device_registry:
entity_refs = self.device_registry.get(entity.zha_device.ieee) entity_refs = self.device_registry.get(entity.zha_device.ieee)
@ -398,7 +417,7 @@ class ZHAGateway:
] ]
def _cleanup_group_entity_registry_entries( def _cleanup_group_entity_registry_entries(
self, zigpy_group: ZigpyGroupType self, zigpy_group: zigpy.group.Group
) -> None: ) -> None:
"""Remove entity registry entries for group entities when the groups are removed from HA.""" """Remove entity registry entries for group entities when the groups are removed from HA."""
# first we collect the potential unique ids for entities that could be created from this group # first we collect the potential unique ids for entities that could be created from this group
@ -429,17 +448,17 @@ class ZHAGateway:
self.ha_entity_registry.async_remove(entry.entity_id) self.ha_entity_registry.async_remove(entry.entity_id)
@property @property
def devices(self): def devices(self) -> dict[EUI64, ZHADevice]:
"""Return devices.""" """Return devices."""
return self._devices return self._devices
@property @property
def groups(self): def groups(self) -> dict[int, ZHAGroup]:
"""Return groups.""" """Return groups."""
return self._groups return self._groups
@property @property
def device_registry(self): def device_registry(self) -> collections.defaultdict[EUI64, list[EntityReference]]:
"""Return entities by ieee.""" """Return entities by ieee."""
return self._device_registry return self._device_registry
@ -464,7 +483,7 @@ class ZHAGateway:
) )
@callback @callback
def async_enable_debug_mode(self, filterer=None): def async_enable_debug_mode(self, filterer: _LogFilterType | None = None) -> None:
"""Enable debug mode for ZHA.""" """Enable debug mode for ZHA."""
self._log_levels[DEBUG_LEVEL_ORIGINAL] = async_capture_log_levels() self._log_levels[DEBUG_LEVEL_ORIGINAL] = async_capture_log_levels()
async_set_logger_levels(DEBUG_LEVELS) async_set_logger_levels(DEBUG_LEVELS)
@ -479,7 +498,7 @@ class ZHAGateway:
self.debug_enabled = True self.debug_enabled = True
@callback @callback
def async_disable_debug_mode(self, filterer=None): def async_disable_debug_mode(self, filterer: _LogFilterType | None = None) -> None:
"""Disable debug mode for ZHA.""" """Disable debug mode for ZHA."""
async_set_logger_levels(self._log_levels[DEBUG_LEVEL_ORIGINAL]) async_set_logger_levels(self._log_levels[DEBUG_LEVEL_ORIGINAL])
self._log_levels[DEBUG_LEVEL_CURRENT] = async_capture_log_levels() self._log_levels[DEBUG_LEVEL_CURRENT] = async_capture_log_levels()
@ -491,8 +510,8 @@ class ZHAGateway:
@callback @callback
def _async_get_or_create_device( def _async_get_or_create_device(
self, zigpy_device: zha_typing.ZigpyDeviceType, restored: bool = False self, zigpy_device: zigpy.device.Device, restored: bool = False
): ) -> ZHADevice:
"""Get or create a ZHA device.""" """Get or create a ZHA device."""
if (zha_device := self._devices.get(zigpy_device.ieee)) is None: if (zha_device := self._devices.get(zigpy_device.ieee)) is None:
zha_device = ZHADevice.new(self._hass, zigpy_device, self, restored) zha_device = ZHADevice.new(self._hass, zigpy_device, self, restored)
@ -511,7 +530,7 @@ class ZHAGateway:
return zha_device return zha_device
@callback @callback
def _async_get_or_create_group(self, zigpy_group: ZigpyGroupType) -> ZhaGroupType: def _async_get_or_create_group(self, zigpy_group: zigpy.group.Group) -> ZHAGroup:
"""Get or create a ZHA group.""" """Get or create a ZHA group."""
zha_group = self._groups.get(zigpy_group.group_id) zha_group = self._groups.get(zigpy_group.group_id)
if zha_group is None: if zha_group is None:
@ -521,7 +540,7 @@ class ZHAGateway:
@callback @callback
def async_update_device( def async_update_device(
self, sender: zigpy_dev.Device, available: bool = True self, sender: zigpy.device.Device, available: bool = True
) -> None: ) -> None:
"""Update device that has just become available.""" """Update device that has just become available."""
if sender.ieee in self.devices: if sender.ieee in self.devices:
@ -530,12 +549,12 @@ class ZHAGateway:
if device.status is DeviceStatus.INITIALIZED: if device.status is DeviceStatus.INITIALIZED:
device.update_available(available) device.update_available(available)
async def async_update_device_storage(self, *_): async def async_update_device_storage(self, *_: Any) -> None:
"""Update the devices in the store.""" """Update the devices in the store."""
for device in self.devices.values(): for device in self.devices.values():
self.zha_storage.async_update_device(device) self.zha_storage.async_update_device(device)
async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType): async def async_device_initialized(self, device: zigpy.device.Device) -> None:
"""Handle device joined and basic information discovered (async).""" """Handle device joined and basic information discovered (async)."""
zha_device = self._async_get_or_create_device(device) zha_device = self._async_get_or_create_device(device)
# This is an active device so set a last seen if it is none # This is an active device so set a last seen if it is none
@ -576,7 +595,7 @@ class ZHAGateway:
}, },
) )
async def _async_device_joined(self, zha_device: zha_typing.ZhaDeviceType) -> None: async def _async_device_joined(self, zha_device: ZHADevice) -> None:
zha_device.available = True zha_device.available = True
device_info = zha_device.device_info device_info = zha_device.device_info
await zha_device.async_configure() await zha_device.async_configure()
@ -592,7 +611,7 @@ class ZHAGateway:
await zha_device.async_initialize(from_cache=False) await zha_device.async_initialize(from_cache=False)
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
async def _async_device_rejoined(self, zha_device): async def _async_device_rejoined(self, zha_device: ZHADevice) -> None:
_LOGGER.debug( _LOGGER.debug(
"skipping discovery for previously discovered device - %s:%s", "skipping discovery for previously discovered device - %s:%s",
zha_device.nwk, zha_device.nwk,
@ -615,8 +634,11 @@ class ZHAGateway:
zha_device.update_available(True) zha_device.update_available(True)
async def async_create_zigpy_group( async def async_create_zigpy_group(
self, name: str, members: list[GroupMember], group_id: int = None self,
) -> ZhaGroupType: name: str,
members: list[GroupMember] | None,
group_id: int | None = None,
) -> ZHAGroup | None:
"""Create a new Zigpy Zigbee group.""" """Create a new Zigpy Zigbee group."""
# we start with two to fill any gaps from a user removing existing groups # we start with two to fill any gaps from a user removing existing groups
@ -659,7 +681,7 @@ class ZHAGateway:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
self.application_controller.groups.pop(group_id) self.application_controller.groups.pop(group_id)
async def shutdown(self): async def shutdown(self) -> None:
"""Stop ZHA Controller Application.""" """Stop ZHA Controller Application."""
_LOGGER.debug("Shutting down ZHA ControllerApplication") _LOGGER.debug("Shutting down ZHA ControllerApplication")
for unsubscribe in self._unsubs: for unsubscribe in self._unsubs:
@ -668,7 +690,7 @@ class ZHAGateway:
def handle_message( def handle_message(
self, self,
sender: zigpy_dev.Device, sender: zigpy.device.Device,
profile: int, profile: int,
cluster: int, cluster: int,
src_ep: int, src_ep: int,
@ -681,7 +703,7 @@ class ZHAGateway:
@callback @callback
def async_capture_log_levels(): def async_capture_log_levels() -> dict[str, int]:
"""Capture current logger levels for ZHA.""" """Capture current logger levels for ZHA."""
return { return {
DEBUG_COMP_BELLOWS: logging.getLogger(DEBUG_COMP_BELLOWS).getEffectiveLevel(), DEBUG_COMP_BELLOWS: logging.getLogger(DEBUG_COMP_BELLOWS).getEffectiveLevel(),
@ -703,7 +725,7 @@ def async_capture_log_levels():
@callback @callback
def async_set_logger_levels(levels): def async_set_logger_levels(levels: dict[str, int]) -> None:
"""Set logger levels for ZHA.""" """Set logger levels for ZHA."""
logging.getLogger(DEBUG_COMP_BELLOWS).setLevel(levels[DEBUG_COMP_BELLOWS]) logging.getLogger(DEBUG_COMP_BELLOWS).setLevel(levels[DEBUG_COMP_BELLOWS])
logging.getLogger(DEBUG_COMP_ZHA).setLevel(levels[DEBUG_COMP_ZHA]) logging.getLogger(DEBUG_COMP_ZHA).setLevel(levels[DEBUG_COMP_ZHA])
@ -717,13 +739,13 @@ def async_set_logger_levels(levels):
class LogRelayHandler(logging.Handler): class LogRelayHandler(logging.Handler):
"""Log handler for error messages.""" """Log handler for error messages."""
def __init__(self, hass, gateway): def __init__(self, hass: HomeAssistant, gateway: ZHAGateway) -> None:
"""Initialize a new LogErrorHandler.""" """Initialize a new LogErrorHandler."""
super().__init__() super().__init__()
self.hass = hass self.hass = hass
self.gateway = gateway self.gateway = gateway
def emit(self, record): def emit(self, record: LogRecord) -> None:
"""Relay log message via dispatcher.""" """Relay log message via dispatcher."""
stack = [] stack = []
if record.levelno >= logging.WARN and not record.exc_info: if record.levelno >= logging.WARN and not record.exc_info: