diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index 1d5656a1b8d..4f1e80e0a7b 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -109,8 +109,8 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b device_registry = await hass.helpers.device_registry.async_get_registry() device_registry.async_get_or_create( config_entry_id=config_entry.entry_id, - connections={(CONNECTION_ZIGBEE, str(zha_gateway.application_controller.ieee))}, - identifiers={(DOMAIN, str(zha_gateway.application_controller.ieee))}, + connections={(CONNECTION_ZIGBEE, str(zha_gateway.application_controller.ieee))}, # type: ignore[attr-defined] + identifiers={(DOMAIN, str(zha_gateway.application_controller.ieee))}, # type: ignore[attr-defined] name="Zigbee Coordinator", manufacturer="ZHA", model=zha_gateway.radio_description, diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 252893683ef..636e161d45c 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import collections +from collections.abc import Callable from datetime import timedelta from enum import Enum import itertools @@ -10,13 +11,18 @@ import logging import os import time import traceback +from typing import TYPE_CHECKING, Any, Union from serial import SerialException from zigpy.config import CONF_DEVICE -import zigpy.device as zigpy_dev +import zigpy.device +import zigpy.endpoint +import zigpy.group +from zigpy.types.named import EUI64 from homeassistant.components.system_log import LogEntry, _figure_out_source -from homeassistant.core import callback +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers.device_registry import ( CONNECTION_ZIGBEE, @@ -28,8 +34,9 @@ from homeassistant.helpers.entity_registry import ( async_get_registry as get_ent_reg, ) from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.typing import ConfigType -from . import discovery, typing as zha_typing +from . import discovery from .const import ( ATTR_IEEE, ATTR_MANUFACTURER, @@ -81,7 +88,13 @@ from .device import DeviceStatus, ZHADevice from .group import GroupMember, ZHAGroup from .registries import GROUP_ENTITY_DOMAINS from .store import async_get_registry -from .typing import ZhaGroupType, ZigpyEndpointType, ZigpyGroupType + +if TYPE_CHECKING: + from logging import Filter, LogRecord + + from ..entity import ZhaEntity + + _LogFilterType = Union[Filter, Callable[[LogRecord], int]] _LOGGER = logging.getLogger(__name__) @@ -103,29 +116,33 @@ class DevicePairingStatus(Enum): class ZHAGateway: """Gateway that handles events that happen on the ZHA Zigbee network.""" - def __init__(self, hass, config, config_entry): + def __init__( + self, hass: HomeAssistant, config: ConfigType, config_entry: ConfigEntry + ) -> None: """Initialize the gateway.""" self._hass = hass self._config = config - self._devices = {} - self._groups = {} - self.coordinator_zha_device = None - self._device_registry = collections.defaultdict(list) + self._devices: dict[EUI64, ZHADevice] = {} + self._groups: dict[int, ZHAGroup] = {} + self.coordinator_zha_device: ZHADevice | None = None + self._device_registry: collections.defaultdict[ + EUI64, list[EntityReference] + ] = collections.defaultdict(list) self.zha_storage = None self.ha_device_registry = None self.ha_entity_registry = None self.application_controller = None self.radio_description = None - self._log_levels = { + self._log_levels: dict[str, dict[str, int]] = { DEBUG_LEVEL_ORIGINAL: async_capture_log_levels(), DEBUG_LEVEL_CURRENT: async_capture_log_levels(), } self.debug_enabled = False self._log_relay_handler = LogRelayHandler(hass, self) self.config_entry = config_entry - self._unsubs = [] + self._unsubs: list[Callable[[], None]] = [] - async def async_initialize(self): + async def async_initialize(self) -> None: """Initialize controller and connect radio.""" discovery.PROBE.initialize(self._hass) discovery.GROUP_PROBE.initialize(self._hass) @@ -211,7 +228,7 @@ class ZHAGateway: """Initialize devices and load entities.""" semaphore = asyncio.Semaphore(2) - async def _throttle(zha_device: zha_typing.ZhaDeviceType, cached: bool): + async def _throttle(zha_device: ZHADevice, cached: bool) -> None: async with semaphore: await zha_device.async_initialize(from_cache=cached) @@ -233,7 +250,7 @@ class ZHAGateway: ) ) - def device_joined(self, device): + def device_joined(self, device: zigpy.device.Device) -> None: """Handle device joined. At this point, no information about the device is known other than its @@ -252,7 +269,7 @@ class ZHAGateway: }, ) - def raw_device_initialized(self, device): + def raw_device_initialized(self, device: zigpy.device.Device) -> None: """Handle a device initialization without quirks loaded.""" manuf = device.manufacturer async_dispatcher_send( @@ -271,16 +288,16 @@ class ZHAGateway: }, ) - def device_initialized(self, device): + def device_initialized(self, device: zigpy.device.Device) -> None: """Handle device joined and basic information discovered.""" self._hass.async_create_task(self.async_device_initialized(device)) - def device_left(self, device: zigpy_dev.Device): + def device_left(self, device: zigpy.device.Device) -> None: """Handle device leaving the network.""" self.async_update_device(device, False) def group_member_removed( - self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType + self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint ) -> None: """Handle zigpy group member removed event.""" # need to handle endpoint correctly on groups @@ -292,7 +309,7 @@ class ZHAGateway: ) def group_member_added( - self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType + self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint ) -> None: """Handle zigpy group member added event.""" # need to handle endpoint correctly on groups @@ -306,14 +323,14 @@ class ZHAGateway: # we need to do this because there wasn't already a group entity to remove and re-add discovery.GROUP_PROBE.discover_group_entities(zha_group) - def group_added(self, zigpy_group: ZigpyGroupType) -> None: + def group_added(self, zigpy_group: zigpy.group.Group) -> None: """Handle zigpy group added event.""" zha_group = self._async_get_or_create_group(zigpy_group) zha_group.info("group_added") # need to dispatch for entity creation here self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED) - def group_removed(self, zigpy_group: ZigpyGroupType) -> None: + def group_removed(self, zigpy_group: zigpy.group.Group) -> None: """Handle zigpy group removed event.""" self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED) zha_group = self._groups.pop(zigpy_group.group_id, None) @@ -321,7 +338,7 @@ class ZHAGateway: self._cleanup_group_entity_registry_entries(zigpy_group) def _send_group_gateway_message( - self, zigpy_group: ZigpyGroupType, gateway_message_type: str + self, zigpy_group: zigpy.group.Group, gateway_message_type: str ) -> None: """Send the gateway event for a zigpy group event.""" zha_group = self._groups.get(zigpy_group.group_id) @@ -335,7 +352,9 @@ class ZHAGateway: }, ) - async def _async_remove_device(self, device, entity_refs): + async def _async_remove_device( + self, device: ZHADevice, entity_refs: list[EntityReference] | None + ) -> None: if entity_refs is not None: remove_tasks = [] for entity_ref in entity_refs: @@ -346,7 +365,7 @@ class ZHAGateway: if reg_device is not None: self.ha_device_registry.async_remove_device(reg_device.id) - def device_removed(self, device): + def device_removed(self, device: zigpy.device.Device) -> None: """Handle device being removed from the network.""" zha_device = self._devices.pop(device.ieee, None) entity_refs = self._device_registry.pop(device.ieee, None) @@ -365,23 +384,23 @@ class ZHAGateway: }, ) - def get_device(self, ieee): + def get_device(self, ieee: EUI64) -> ZHADevice | None: """Return ZHADevice for given ieee.""" return self._devices.get(ieee) - def get_group(self, group_id: str) -> ZhaGroupType | None: + def get_group(self, group_id: int) -> ZHAGroup | None: """Return Group for given group id.""" return self.groups.get(group_id) @callback - def async_get_group_by_name(self, group_name: str) -> ZhaGroupType | None: + def async_get_group_by_name(self, group_name: str) -> ZHAGroup | None: """Get ZHA group by name.""" for group in self.groups.values(): if group.name == group_name: return group return None - def get_entity_reference(self, entity_id): + def get_entity_reference(self, entity_id: str) -> EntityReference | None: """Return entity reference for given entity_id if found.""" for entity_reference in itertools.chain.from_iterable( self.device_registry.values() @@ -389,7 +408,7 @@ class ZHAGateway: if entity_id == entity_reference.reference_id: return entity_reference - def remove_entity_reference(self, entity): + def remove_entity_reference(self, entity: ZhaEntity) -> None: """Remove entity reference for given entity_id if found.""" if entity.zha_device.ieee in self.device_registry: entity_refs = self.device_registry.get(entity.zha_device.ieee) @@ -398,7 +417,7 @@ class ZHAGateway: ] def _cleanup_group_entity_registry_entries( - self, zigpy_group: ZigpyGroupType + self, zigpy_group: zigpy.group.Group ) -> None: """Remove entity registry entries for group entities when the groups are removed from HA.""" # first we collect the potential unique ids for entities that could be created from this group @@ -429,17 +448,17 @@ class ZHAGateway: self.ha_entity_registry.async_remove(entry.entity_id) @property - def devices(self): + def devices(self) -> dict[EUI64, ZHADevice]: """Return devices.""" return self._devices @property - def groups(self): + def groups(self) -> dict[int, ZHAGroup]: """Return groups.""" return self._groups @property - def device_registry(self): + def device_registry(self) -> collections.defaultdict[EUI64, list[EntityReference]]: """Return entities by ieee.""" return self._device_registry @@ -464,7 +483,7 @@ class ZHAGateway: ) @callback - def async_enable_debug_mode(self, filterer=None): + def async_enable_debug_mode(self, filterer: _LogFilterType | None = None) -> None: """Enable debug mode for ZHA.""" self._log_levels[DEBUG_LEVEL_ORIGINAL] = async_capture_log_levels() async_set_logger_levels(DEBUG_LEVELS) @@ -479,7 +498,7 @@ class ZHAGateway: self.debug_enabled = True @callback - def async_disable_debug_mode(self, filterer=None): + def async_disable_debug_mode(self, filterer: _LogFilterType | None = None) -> None: """Disable debug mode for ZHA.""" async_set_logger_levels(self._log_levels[DEBUG_LEVEL_ORIGINAL]) self._log_levels[DEBUG_LEVEL_CURRENT] = async_capture_log_levels() @@ -491,8 +510,8 @@ class ZHAGateway: @callback def _async_get_or_create_device( - self, zigpy_device: zha_typing.ZigpyDeviceType, restored: bool = False - ): + self, zigpy_device: zigpy.device.Device, restored: bool = False + ) -> ZHADevice: """Get or create a ZHA device.""" if (zha_device := self._devices.get(zigpy_device.ieee)) is None: zha_device = ZHADevice.new(self._hass, zigpy_device, self, restored) @@ -511,7 +530,7 @@ class ZHAGateway: return zha_device @callback - def _async_get_or_create_group(self, zigpy_group: ZigpyGroupType) -> ZhaGroupType: + def _async_get_or_create_group(self, zigpy_group: zigpy.group.Group) -> ZHAGroup: """Get or create a ZHA group.""" zha_group = self._groups.get(zigpy_group.group_id) if zha_group is None: @@ -521,7 +540,7 @@ class ZHAGateway: @callback def async_update_device( - self, sender: zigpy_dev.Device, available: bool = True + self, sender: zigpy.device.Device, available: bool = True ) -> None: """Update device that has just become available.""" if sender.ieee in self.devices: @@ -530,12 +549,12 @@ class ZHAGateway: if device.status is DeviceStatus.INITIALIZED: device.update_available(available) - async def async_update_device_storage(self, *_): + async def async_update_device_storage(self, *_: Any) -> None: """Update the devices in the store.""" for device in self.devices.values(): self.zha_storage.async_update_device(device) - async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType): + async def async_device_initialized(self, device: zigpy.device.Device) -> None: """Handle device joined and basic information discovered (async).""" zha_device = self._async_get_or_create_device(device) # This is an active device so set a last seen if it is none @@ -576,7 +595,7 @@ class ZHAGateway: }, ) - async def _async_device_joined(self, zha_device: zha_typing.ZhaDeviceType) -> None: + async def _async_device_joined(self, zha_device: ZHADevice) -> None: zha_device.available = True device_info = zha_device.device_info await zha_device.async_configure() @@ -592,7 +611,7 @@ class ZHAGateway: await zha_device.async_initialize(from_cache=False) async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) - async def _async_device_rejoined(self, zha_device): + async def _async_device_rejoined(self, zha_device: ZHADevice) -> None: _LOGGER.debug( "skipping discovery for previously discovered device - %s:%s", zha_device.nwk, @@ -615,8 +634,11 @@ class ZHAGateway: zha_device.update_available(True) async def async_create_zigpy_group( - self, name: str, members: list[GroupMember], group_id: int = None - ) -> ZhaGroupType: + self, + name: str, + members: list[GroupMember] | None, + group_id: int | None = None, + ) -> ZHAGroup | None: """Create a new Zigpy Zigbee group.""" # we start with two to fill any gaps from a user removing existing groups @@ -659,7 +681,7 @@ class ZHAGateway: await asyncio.gather(*tasks) self.application_controller.groups.pop(group_id) - async def shutdown(self): + async def shutdown(self) -> None: """Stop ZHA Controller Application.""" _LOGGER.debug("Shutting down ZHA ControllerApplication") for unsubscribe in self._unsubs: @@ -668,7 +690,7 @@ class ZHAGateway: def handle_message( self, - sender: zigpy_dev.Device, + sender: zigpy.device.Device, profile: int, cluster: int, src_ep: int, @@ -681,7 +703,7 @@ class ZHAGateway: @callback -def async_capture_log_levels(): +def async_capture_log_levels() -> dict[str, int]: """Capture current logger levels for ZHA.""" return { DEBUG_COMP_BELLOWS: logging.getLogger(DEBUG_COMP_BELLOWS).getEffectiveLevel(), @@ -703,7 +725,7 @@ def async_capture_log_levels(): @callback -def async_set_logger_levels(levels): +def async_set_logger_levels(levels: dict[str, int]) -> None: """Set logger levels for ZHA.""" logging.getLogger(DEBUG_COMP_BELLOWS).setLevel(levels[DEBUG_COMP_BELLOWS]) logging.getLogger(DEBUG_COMP_ZHA).setLevel(levels[DEBUG_COMP_ZHA]) @@ -717,13 +739,13 @@ def async_set_logger_levels(levels): class LogRelayHandler(logging.Handler): """Log handler for error messages.""" - def __init__(self, hass, gateway): + def __init__(self, hass: HomeAssistant, gateway: ZHAGateway) -> None: """Initialize a new LogErrorHandler.""" super().__init__() self.hass = hass self.gateway = gateway - def emit(self, record): + def emit(self, record: LogRecord) -> None: """Relay log message via dispatcher.""" stack = [] if record.levelno >= logging.WARN and not record.exc_info: