From 0bae0824b42ee4f5e18cff36ce3b3ddc99487bc0 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 31 Aug 2023 12:09:46 -0400 Subject: [PATCH] Initialize ZHA device database before connecting to the radio (#98082) * Create ZHA entities before attempting to connect to the coordinator * Delete the ZHA gateway object when unloading the config entry * Only load ZHA groups if the coordinator device info is known offline * Do not create a coordinator ZHA device until it is ready * [WIP] begin fixing unit tests * [WIP] Fix existing unit tests (one failure left) * Fix remaining unit test --- homeassistant/components/zha/__init__.py | 20 +---- homeassistant/components/zha/core/const.py | 1 - homeassistant/components/zha/core/gateway.py | 53 ++++++++++--- homeassistant/components/zha/core/helpers.py | 6 +- tests/components/zha/common.py | 10 +-- tests/components/zha/conftest.py | 26 ++++++- tests/components/zha/test_api.py | 2 +- tests/components/zha/test_gateway.py | 79 +++++++++++++------- tests/components/zha/test_websocket_api.py | 11 ++- 9 files changed, 133 insertions(+), 75 deletions(-) diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index a51d6f387e1..e48f8ce2096 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -10,7 +10,7 @@ from zhaquirks import setup as setup_quirks from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_TYPE, EVENT_HOMEASSISTANT_STOP +from homeassistant.const import CONF_TYPE from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr import homeassistant.helpers.config_validation as cv @@ -33,7 +33,6 @@ from .core.const import ( DATA_ZHA, DATA_ZHA_CONFIG, DATA_ZHA_GATEWAY, - DATA_ZHA_SHUTDOWN_TASK, DOMAIN, PLATFORMS, SIGNAL_ADD_ENTITIES, @@ -137,6 +136,8 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b zha_gateway = ZHAGateway(hass, config, config_entry) await zha_gateway.async_initialize() + config_entry.async_on_unload(zha_gateway.shutdown) + device_registry = dr.async_get(hass) device_registry.async_get_or_create( config_entry_id=config_entry.entry_id, @@ -149,15 +150,6 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b websocket_api.async_load_api(hass) - async def async_zha_shutdown(event): - """Handle shutdown tasks.""" - zha_gateway: ZHAGateway = zha_data[DATA_ZHA_GATEWAY] - await zha_gateway.shutdown() - - zha_data[DATA_ZHA_SHUTDOWN_TASK] = hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_STOP, async_zha_shutdown - ) - await zha_gateway.async_initialize_devices_and_entities() await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS) async_dispatcher_send(hass, SIGNAL_ADD_ENTITIES) @@ -167,12 +159,10 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: """Unload ZHA config entry.""" try: - zha_gateway: ZHAGateway = hass.data[DATA_ZHA].pop(DATA_ZHA_GATEWAY) + del hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] except KeyError: return False - await zha_gateway.shutdown() - GROUP_PROBE.cleanup() websocket_api.async_unload_api(hass) @@ -184,8 +174,6 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> ) ) - hass.data[DATA_ZHA][DATA_ZHA_SHUTDOWN_TASK]() - return True diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index 7aab6112ab0..63b59e9d8d4 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -187,7 +187,6 @@ DATA_ZHA_CONFIG = "config" DATA_ZHA_BRIDGE_ID = "zha_bridge_id" DATA_ZHA_CORE_EVENTS = "zha_core_events" DATA_ZHA_GATEWAY = "zha_gateway" -DATA_ZHA_SHUTDOWN_TASK = "zha_shutdown_task" DEBUG_COMP_BELLOWS = "bellows" DEBUG_COMP_ZHA = "homeassistant.components.zha" diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 1320e77ba3c..3abf1274f98 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -148,7 +148,6 @@ class ZHAGateway: self._log_relay_handler = LogRelayHandler(hass, self) self.config_entry = config_entry self._unsubs: list[Callable[[], None]] = [] - self.initialized: bool = False def get_application_controller_data(self) -> tuple[ControllerApplication, dict]: """Get an uninitialized instance of a zigpy `ControllerApplication`.""" @@ -199,12 +198,32 @@ class ZHAGateway: self.ha_entity_registry = er.async_get(self._hass) app_controller_cls, app_config = self.get_application_controller_data() + self.application_controller = await app_controller_cls.new( + config=app_config, + auto_form=False, + start_radio=False, + ) + + self._hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] = self + + self.async_load_devices() + + # Groups are attached to the coordinator device so we need to load it early + coordinator = self._find_coordinator_device() + loaded_groups = False + + # We can only load groups early if the coordinator's model info has been stored + # in the zigpy database + if coordinator.model is not None: + self.coordinator_zha_device = self._async_get_or_create_device( + coordinator, restored=True + ) + self.async_load_groups() + loaded_groups = True for attempt in range(STARTUP_RETRIES): try: - self.application_controller = await app_controller_cls.new( - app_config, auto_form=True, start_radio=True - ) + await self.application_controller.startup(auto_form=True) except zigpy.exceptions.TransientConnectionError as exc: raise ConfigEntryNotReady from exc except Exception as exc: # pylint: disable=broad-except @@ -223,21 +242,33 @@ class ZHAGateway: else: break + self.coordinator_zha_device = self._async_get_or_create_device( + self._find_coordinator_device(), restored=True + ) + self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str(self.coordinator_ieee) + + # If ZHA groups could not load early, we can safely load them now + if not loaded_groups: + self.async_load_groups() + self.application_controller.add_listener(self) self.application_controller.groups.add_listener(self) - self._hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] = self - self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str(self.coordinator_ieee) - self.async_load_devices() - self.async_load_groups() - self.initialized = True + + def _find_coordinator_device(self) -> zigpy.device.Device: + if last_backup := self.application_controller.backups.most_recent_backup(): + zigpy_coordinator = self.application_controller.get_device( + ieee=last_backup.node_info.ieee + ) + else: + zigpy_coordinator = self.application_controller.get_device(nwk=0x0000) + + return zigpy_coordinator @callback def async_load_devices(self) -> None: """Restore ZHA devices from zigpy application state.""" for zigpy_device in self.application_controller.devices.values(): zha_device = self._async_get_or_create_device(zigpy_device, restored=True) - if zha_device.ieee == self.coordinator_ieee: - self.coordinator_zha_device = zha_device delta_msg = "not known" if zha_device.last_seen is not None: delta = round(time.time() - zha_device.last_seen) diff --git a/homeassistant/components/zha/core/helpers.py b/homeassistant/components/zha/core/helpers.py index ac7c15d3ecd..7b0d062738b 100644 --- a/homeassistant/components/zha/core/helpers.py +++ b/homeassistant/components/zha/core/helpers.py @@ -27,7 +27,6 @@ import zigpy.zdo.types as zdo_types from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, State, callback -from homeassistant.exceptions import IntegrationError from homeassistant.helpers import config_validation as cv, device_registry as dr from .const import ( @@ -246,11 +245,8 @@ def async_get_zha_device(hass: HomeAssistant, device_id: str) -> ZHADevice: _LOGGER.error("Device id `%s` not found in registry", device_id) raise KeyError(f"Device id `{device_id}` not found in registry.") zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - if not zha_gateway.initialized: - _LOGGER.error("Attempting to get a ZHA device when ZHA is not initialized") - raise IntegrationError("ZHA is not initialized yet") try: - ieee_address = list(list(registry_device.identifiers)[0])[1] + ieee_address = list(registry_device.identifiers)[0][1] ieee = zigpy.types.EUI64.convert(ieee_address) except (IndexError, ValueError) as ex: _LOGGER.error( diff --git a/tests/components/zha/common.py b/tests/components/zha/common.py index 01206c432e6..db1da3721ee 100644 --- a/tests/components/zha/common.py +++ b/tests/components/zha/common.py @@ -87,10 +87,7 @@ def update_attribute_cache(cluster): def get_zha_gateway(hass): """Return ZHA gateway from hass.data.""" - try: - return hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] - except KeyError: - return None + return hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] def make_attribute(attrid, value, status=0): @@ -167,12 +164,9 @@ def find_entity_ids(domain, zha_device, hass): def async_find_group_entity_id(hass, domain, group): """Find the group entity id under test.""" - entity_id = ( - f"{domain}.fakemanufacturer_fakemodel_{group.name.lower().replace(' ', '_')}" - ) + entity_id = f"{domain}.coordinator_manufacturer_coordinator_model_{group.name.lower().replace(' ', '_')}" entity_ids = hass.states.async_entity_ids(domain) - assert entity_id in entity_ids return entity_id diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index dd2c200973c..f690a5152fc 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -16,6 +16,8 @@ import zigpy.profiles import zigpy.quirks import zigpy.types import zigpy.util +from zigpy.zcl.clusters.general import Basic, Groups +from zigpy.zcl.foundation import Status import zigpy.zdo.types as zdo_t import homeassistant.components.zha.core.const as zha_const @@ -116,6 +118,9 @@ def zigpy_app_controller(): { zigpy.config.CONF_DATABASE: None, zigpy.config.CONF_DEVICE: {zigpy.config.CONF_DEVICE_PATH: "/dev/null"}, + zigpy.config.CONF_STARTUP_ENERGY_SCAN: False, + zigpy.config.CONF_NWK_BACKUP_ENABLED: False, + zigpy.config.CONF_TOPO_SCAN_ENABLED: False, } ) @@ -128,9 +133,24 @@ def zigpy_app_controller(): app.state.network_info.channel = 15 app.state.network_info.network_key.key = zigpy.types.KeyData(range(16)) - with patch("zigpy.device.Device.request"), patch.object( - app, "permit", autospec=True - ), patch.object(app, "permit_with_key", autospec=True): + # Create a fake coordinator device + dev = app.add_device(nwk=app.state.node_info.nwk, ieee=app.state.node_info.ieee) + dev.node_desc = zdo_t.NodeDescriptor() + dev.node_desc.logical_type = zdo_t.LogicalType.Coordinator + dev.manufacturer = "Coordinator Manufacturer" + dev.model = "Coordinator Model" + + ep = dev.add_endpoint(1) + ep.add_input_cluster(Basic.cluster_id) + ep.add_input_cluster(Groups.cluster_id) + + with patch( + "zigpy.device.Device.request", return_value=[Status.SUCCESS] + ), patch.object(app, "permit", autospec=True), patch.object( + app, "startup", wraps=app.startup + ), patch.object( + app, "permit_with_key", autospec=True + ): yield app diff --git a/tests/components/zha/test_api.py b/tests/components/zha/test_api.py index 85f85cc0437..c2cb16efcc8 100644 --- a/tests/components/zha/test_api.py +++ b/tests/components/zha/test_api.py @@ -71,7 +71,7 @@ async def test_async_get_network_settings_missing( await setup_zha() gateway = api._get_gateway(hass) - await zha.async_unload_entry(hass, gateway.config_entry) + await gateway.config_entry.async_unload(hass) # Network settings were never loaded for whatever reason zigpy_app_controller.state.network_info = zigpy.state.NetworkInfo() diff --git a/tests/components/zha/test_gateway.py b/tests/components/zha/test_gateway.py index b9fcd4b6932..0f791a08955 100644 --- a/tests/components/zha/test_gateway.py +++ b/tests/components/zha/test_gateway.py @@ -1,9 +1,9 @@ """Test ZHA Gateway.""" import asyncio -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest +from zigpy.application import ControllerApplication import zigpy.exceptions import zigpy.profiles.zha as zha import zigpy.zcl.clusters.general as general @@ -232,68 +232,89 @@ async def test_gateway_create_group_with_id( ) @patch("homeassistant.components.zha.core.gateway.STARTUP_FAILURE_DELAY_S", 0.01) @pytest.mark.parametrize( - "startup", + "startup_effect", [ - [asyncio.TimeoutError(), FileNotFoundError(), MagicMock()], - [asyncio.TimeoutError(), MagicMock()], - [MagicMock()], + [asyncio.TimeoutError(), FileNotFoundError(), None], + [asyncio.TimeoutError(), None], + [None], ], ) async def test_gateway_initialize_success( - startup: list[Any], + startup_effect: list[Exception | None], hass: HomeAssistant, device_light_1: ZHADevice, coordinator: ZHADevice, + zigpy_app_controller: ControllerApplication, ) -> None: """Test ZHA initializing the gateway successfully.""" zha_gateway = get_zha_gateway(hass) assert zha_gateway is not None - zha_gateway.shutdown = AsyncMock() + zigpy_app_controller.startup.side_effect = startup_effect + zigpy_app_controller.startup.reset_mock() with patch( - "bellows.zigbee.application.ControllerApplication.new", side_effect=startup - ) as mock_new: + "bellows.zigbee.application.ControllerApplication.new", + return_value=zigpy_app_controller, + ): await zha_gateway.async_initialize() - assert mock_new.call_count == len(startup) - + assert zigpy_app_controller.startup.call_count == len(startup_effect) device_light_1.async_cleanup_handles() @patch("homeassistant.components.zha.core.gateway.STARTUP_FAILURE_DELAY_S", 0.01) async def test_gateway_initialize_failure( - hass: HomeAssistant, device_light_1, coordinator + hass: HomeAssistant, + device_light_1: ZHADevice, + coordinator: ZHADevice, + zigpy_app_controller: ControllerApplication, ) -> None: """Test ZHA failing to initialize the gateway.""" zha_gateway = get_zha_gateway(hass) assert zha_gateway is not None + zigpy_app_controller.startup.side_effect = [ + asyncio.TimeoutError(), + RuntimeError(), + FileNotFoundError(), + ] + zigpy_app_controller.startup.reset_mock() + with patch( "bellows.zigbee.application.ControllerApplication.new", - side_effect=[asyncio.TimeoutError(), FileNotFoundError(), RuntimeError()], - ) as mock_new, pytest.raises(RuntimeError): + return_value=zigpy_app_controller, + ), pytest.raises(FileNotFoundError): await zha_gateway.async_initialize() - assert mock_new.call_count == 3 + assert zigpy_app_controller.startup.call_count == 3 @patch("homeassistant.components.zha.core.gateway.STARTUP_FAILURE_DELAY_S", 0.01) async def test_gateway_initialize_failure_transient( - hass: HomeAssistant, device_light_1, coordinator + hass: HomeAssistant, + device_light_1: ZHADevice, + coordinator: ZHADevice, + zigpy_app_controller: ControllerApplication, ) -> None: """Test ZHA failing to initialize the gateway but with a transient error.""" zha_gateway = get_zha_gateway(hass) assert zha_gateway is not None + zigpy_app_controller.startup.side_effect = [ + RuntimeError(), + zigpy.exceptions.TransientConnectionError(), + ] + zigpy_app_controller.startup.reset_mock() + with patch( "bellows.zigbee.application.ControllerApplication.new", - side_effect=[RuntimeError(), zigpy.exceptions.TransientConnectionError()], - ) as mock_new, pytest.raises(ConfigEntryNotReady): + return_value=zigpy_app_controller, + ), pytest.raises(ConfigEntryNotReady): await zha_gateway.async_initialize() # Initialization immediately stops and is retried after TransientConnectionError - assert mock_new.call_count == 2 + assert zigpy_app_controller.startup.call_count == 2 @patch( @@ -313,7 +334,12 @@ async def test_gateway_initialize_failure_transient( ], ) async def test_gateway_initialize_bellows_thread( - device_path, thread_state, config_override, hass: HomeAssistant, coordinator + device_path: str, + thread_state: bool, + config_override: dict, + hass: HomeAssistant, + coordinator: ZHADevice, + zigpy_app_controller: ControllerApplication, ) -> None: """Test ZHA disabling the UART thread when connecting to a TCP coordinator.""" zha_gateway = get_zha_gateway(hass) @@ -324,15 +350,12 @@ async def test_gateway_initialize_bellows_thread( zha_gateway._config.setdefault("zigpy_config", {}).update(config_override) with patch( - "bellows.zigbee.application.ControllerApplication.new" - ) as controller_app_mock: - mock = AsyncMock() - mock.add_listener = MagicMock() - mock.groups = MagicMock() - controller_app_mock.return_value = mock + "bellows.zigbee.application.ControllerApplication.new", + return_value=zigpy_app_controller, + ) as mock_new: await zha_gateway.async_initialize() - assert controller_app_mock.mock_calls[0].args[0]["use_thread"] is thread_state + assert mock_new.mock_calls[0].kwargs["config"]["use_thread"] is thread_state @pytest.mark.parametrize( diff --git a/tests/components/zha/test_websocket_api.py b/tests/components/zha/test_websocket_api.py index 0904fc1f685..740ffd6c06c 100644 --- a/tests/components/zha/test_websocket_api.py +++ b/tests/components/zha/test_websocket_api.py @@ -13,6 +13,7 @@ import zigpy.profiles.zha import zigpy.types from zigpy.types.named import EUI64 import zigpy.zcl.clusters.general as general +from zigpy.zcl.clusters.general import Groups import zigpy.zcl.clusters.security as security import zigpy.zdo.types as zdo_types @@ -233,7 +234,7 @@ async def test_list_devices(zha_client) -> None: msg = await zha_client.receive_json() devices = msg["result"] - assert len(devices) == 2 + assert len(devices) == 2 + 1 # the coordinator is included as well msg_id = 100 for device in devices: @@ -371,8 +372,13 @@ async def test_get_group_not_found(zha_client) -> None: assert msg["error"]["code"] == const.ERR_NOT_FOUND -async def test_list_groupable_devices(zha_client, device_groupable) -> None: +async def test_list_groupable_devices( + zha_client, device_groupable, zigpy_app_controller +) -> None: """Test getting ZHA devices that have a group cluster.""" + # Ensure the coordinator doesn't have a group cluster + coordinator = zigpy_app_controller.get_device(nwk=0x0000) + del coordinator.endpoints[1].in_clusters[Groups.cluster_id] await zha_client.send_json({ID: 10, TYPE: "zha/devices/groupable"}) @@ -479,6 +485,7 @@ async def app_controller( ) -> ControllerApplication: """Fixture for zigpy Application Controller.""" await setup_zha() + zigpy_app_controller.permit.reset_mock() return zigpy_app_controller