From e0795e6d07fd9a29d58d3a7233fe34a742429528 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 17 Feb 2025 18:16:57 +0100 Subject: [PATCH] Improve config entry state transitions when unloading and removing entries (#138522) * Improve config entry state transitions when unloading and removing entries * Update integrations which check for a single loaded entry * Update tests checking state after unload fails * Update homeassistant/config_entries.py Co-authored-by: Martin Hjelmare --------- Co-authored-by: Martin Hjelmare --- homeassistant/components/adguard/__init__.py | 9 ++------ .../google_assistant_sdk/__init__.py | 9 ++------ .../components/google_mail/__init__.py | 9 ++------ .../components/google_sheets/__init__.py | 9 ++------ homeassistant/components/guardian/__init__.py | 9 ++------ homeassistant/components/lookin/__init__.py | 9 ++------ .../components/motion_blinds/__init__.py | 9 ++------ .../components/netgear_lte/__init__.py | 8 +------ .../components/rainmachine/__init__.py | 9 ++------ .../components/simplisafe/__init__.py | 9 ++------ .../components/tplink_omada/__init__.py | 9 ++------ .../components/xiaomi_aqara/__init__.py | 9 ++------ homeassistant/config_entries.py | 23 +++++++++++++------ tests/components/matter/test_init.py | 2 +- tests/components/unifi/test_hub.py | 2 +- tests/components/zwave_js/test_init.py | 2 +- tests/test_config_entries.py | 8 +++---- 17 files changed, 46 insertions(+), 98 deletions(-) diff --git a/homeassistant/components/adguard/__init__.py b/homeassistant/components/adguard/__init__.py index f8ddeba6767..bbc763d7ec3 100644 --- a/homeassistant/components/adguard/__init__.py +++ b/homeassistant/components/adguard/__init__.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from adguardhome import AdGuardHome, AdGuardHomeConnectionError import voluptuous as vol -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_HOST, CONF_NAME, @@ -123,12 +123,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> b async def async_unload_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> bool: """Unload AdGuard Home config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): # This is the last loaded instance of AdGuard, deregister any services hass.services.async_remove(DOMAIN, SERVICE_ADD_URL) hass.services.async_remove(DOMAIN, SERVICE_REMOVE_URL) diff --git a/homeassistant/components/google_assistant_sdk/__init__.py b/homeassistant/components/google_assistant_sdk/__init__.py index 4ea496f2824..a08d7554516 100644 --- a/homeassistant/components/google_assistant_sdk/__init__.py +++ b/homeassistant/components/google_assistant_sdk/__init__.py @@ -10,7 +10,7 @@ from google.oauth2.credentials import Credentials import voluptuous as vol from homeassistant.components import conversation -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform from homeassistant.core import ( HomeAssistant, @@ -99,12 +99,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" hass.data[DOMAIN].pop(entry.entry_id) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): for service_name in hass.services.async_services_for_domain(DOMAIN): hass.services.async_remove(DOMAIN, service_name) diff --git a/homeassistant/components/google_mail/__init__.py b/homeassistant/components/google_mail/__init__.py index 7fae5f18da5..8ef978568dc 100644 --- a/homeassistant/components/google_mail/__init__.py +++ b/homeassistant/components/google_mail/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_NAME, Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv, discovery @@ -59,12 +59,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: GoogleMailConfigEntry) - async def async_unload_entry(hass: HomeAssistant, entry: GoogleMailConfigEntry) -> bool: """Unload a config entry.""" - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): for service_name in hass.services.async_services_for_domain(DOMAIN): hass.services.async_remove(DOMAIN, service_name) diff --git a/homeassistant/components/google_sheets/__init__.py b/homeassistant/components/google_sheets/__init__.py index faf1ff1ee0b..afafce816a9 100644 --- a/homeassistant/components/google_sheets/__init__.py +++ b/homeassistant/components/google_sheets/__init__.py @@ -12,7 +12,7 @@ from gspread.exceptions import APIError from gspread.utils import ValueInputOption import voluptuous as vol -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.exceptions import ( @@ -81,12 +81,7 @@ async def async_unload_entry( hass: HomeAssistant, entry: GoogleSheetsConfigEntry ) -> bool: """Unload a config entry.""" - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): for service_name in hass.services.async_services_for_domain(DOMAIN): hass.services.async_remove(DOMAIN, service_name) diff --git a/homeassistant/components/guardian/__init__.py b/homeassistant/components/guardian/__init__.py index c1cbb4c0e5a..075c388c4e4 100644 --- a/homeassistant/components/guardian/__init__.py +++ b/homeassistant/components/guardian/__init__.py @@ -11,7 +11,7 @@ from aioguardian import Client from aioguardian.errors import GuardianError import voluptuous as vol -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( ATTR_DEVICE_ID, CONF_DEVICE_ID, @@ -247,12 +247,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if unload_ok: hass.data[DOMAIN].pop(entry.entry_id) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): # If this is the last loaded instance of Guardian, deregister any services # defined during integration setup: for service_name in SERVICES: diff --git a/homeassistant/components/lookin/__init__.py b/homeassistant/components/lookin/__init__.py index 2fbabc12747..247282309e4 100644 --- a/homeassistant/components/lookin/__init__.py +++ b/homeassistant/components/lookin/__init__.py @@ -19,7 +19,7 @@ from aiolookin import ( ) from aiolookin.models import UDPCommandType, UDPEvent -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_HOST, Platform from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import ConfigEntryNotReady @@ -192,12 +192,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): hass.data[DOMAIN].pop(entry.entry_id) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): manager: LookinUDPManager = hass.data[DOMAIN][UDP_MANAGER] await manager.async_stop() return unload_ok diff --git a/homeassistant/components/motion_blinds/__init__.py b/homeassistant/components/motion_blinds/__init__.py index fa1664353e1..df06ffb75fc 100644 --- a/homeassistant/components/motion_blinds/__init__.py +++ b/homeassistant/components/motion_blinds/__init__.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from motionblinds import AsyncMotionMulticast -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_API_KEY, CONF_HOST, EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady @@ -124,12 +124,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> multicast.Unregister_motion_gateway(config_entry.data[CONF_HOST]) hass.data[DOMAIN].pop(config_entry.entry_id) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): # No motion gateways left, stop Motion multicast unsub_stop = hass.data[DOMAIN].pop(KEY_UNSUB_STOP) unsub_stop() diff --git a/homeassistant/components/netgear_lte/__init__.py b/homeassistant/components/netgear_lte/__init__.py index a756d85c866..47a39a39be0 100644 --- a/homeassistant/components/netgear_lte/__init__.py +++ b/homeassistant/components/netgear_lte/__init__.py @@ -6,7 +6,6 @@ from aiohttp.cookiejar import CookieJar import eternalegypt from eternalegypt.eternalegypt import SMS -from homeassistant.config_entries import ConfigEntryState from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady @@ -117,12 +116,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: NetgearLTEConfigEntry) - async def async_unload_entry(hass: HomeAssistant, entry: NetgearLTEConfigEntry) -> bool: """Unload a config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): hass.data.pop(DOMAIN, None) for service_name in hass.services.async_services()[DOMAIN]: hass.services.async_remove(DOMAIN, service_name) diff --git a/homeassistant/components/rainmachine/__init__.py b/homeassistant/components/rainmachine/__init__.py index 4d486c9c6aa..65648b8d44f 100644 --- a/homeassistant/components/rainmachine/__init__.py +++ b/homeassistant/components/rainmachine/__init__.py @@ -13,7 +13,7 @@ from regenmaschine.controller import Controller from regenmaschine.errors import RainMachineError, UnknownAPICallError import voluptuous as vol -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_DEVICE_ID, CONF_IP_ADDRESS, @@ -465,12 +465,7 @@ async def async_unload_entry( ) -> bool: """Unload an RainMachine config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state is ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): # If this is the last loaded instance of RainMachine, deregister any services # defined during integration setup: for service_name in ( diff --git a/homeassistant/components/simplisafe/__init__.py b/homeassistant/components/simplisafe/__init__.py index 2f19c5117a4..8a75baa69c6 100644 --- a/homeassistant/components/simplisafe/__init__.py +++ b/homeassistant/components/simplisafe/__init__.py @@ -39,7 +39,7 @@ from simplipy.websocket import ( ) import voluptuous as vol -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( ATTR_CODE, ATTR_DEVICE_ID, @@ -402,12 +402,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if unload_ok: hass.data[DOMAIN].pop(entry.entry_id) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): # If this is the last loaded instance of SimpliSafe, deregister any services # defined during integration setup: for service_name in SERVICES: diff --git a/homeassistant/components/tplink_omada/__init__.py b/homeassistant/components/tplink_omada/__init__.py index 06df118463b..7ea7fd95fef 100644 --- a/homeassistant/components/tplink_omada/__init__.py +++ b/homeassistant/components/tplink_omada/__init__.py @@ -11,7 +11,7 @@ from tplink_omada_client.exceptions import ( UnsupportedControllerVersion, ) -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady @@ -80,12 +80,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: OmadaConfigEntry) -> boo async def async_unload_entry(hass: HomeAssistant, entry: OmadaConfigEntry) -> bool: """Unload a config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): # This is the last loaded instance of Omada, deregister any services hass.services.async_remove(DOMAIN, "reconnect_client") diff --git a/homeassistant/components/xiaomi_aqara/__init__.py b/homeassistant/components/xiaomi_aqara/__init__.py index 579994aaf6b..6e4d143d84e 100644 --- a/homeassistant/components/xiaomi_aqara/__init__.py +++ b/homeassistant/components/xiaomi_aqara/__init__.py @@ -7,7 +7,7 @@ import voluptuous as vol from xiaomi_gateway import AsyncXiaomiGatewayMulticast, XiaomiGateway from homeassistant.components import persistent_notification -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( ATTR_DEVICE_ID, CONF_HOST, @@ -216,12 +216,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> if unload_ok: hass.data[DOMAIN][GATEWAYS_KEY].pop(config_entry.entry_id) - loaded_entries = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN) - if entry.state == ConfigEntryState.LOADED - ] - if len(loaded_entries) == 1: + if not hass.config_entries.async_loaded_entries(DOMAIN): # No gateways left, stop Xiaomi socket unsub_stop = hass.data[DOMAIN].pop(KEY_UNSUB_STOP) unsub_stop() diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index a103148e3b1..871b476227c 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -155,6 +155,8 @@ class ConfigEntryState(Enum): """An error occurred when trying to unload the entry""" SETUP_IN_PROGRESS = "setup_in_progress", False """The config entry is setting up.""" + UNLOAD_IN_PROGRESS = "unload_in_progress", False + """The config entry is being unloaded.""" _recoverable: bool @@ -955,18 +957,25 @@ class ConfigEntry[_DataT = Any]: ) return False + if domain_is_integration: + self._async_set_state(hass, ConfigEntryState.UNLOAD_IN_PROGRESS, None) try: result = await component.async_unload_entry(hass, self) assert isinstance(result, bool) - # Only adjust state if we unloaded the component - if domain_is_integration and result: - await self._async_process_on_unload(hass) - if hasattr(self, "runtime_data"): - object.__delattr__(self, "runtime_data") + # Only do side effects if we unloaded the integration + if domain_is_integration: + if result: + await self._async_process_on_unload(hass) + if hasattr(self, "runtime_data"): + object.__delattr__(self, "runtime_data") - self._async_set_state(hass, ConfigEntryState.NOT_LOADED, None) + self._async_set_state(hass, ConfigEntryState.NOT_LOADED, None) + else: + self._async_set_state( + hass, ConfigEntryState.FAILED_UNLOAD, "Unload failed" + ) except Exception as exc: _LOGGER.exception( @@ -2052,9 +2061,9 @@ class ConfigEntries: else: unload_success = await self.async_unload(entry_id, _lock=False) + del self._entries[entry.entry_id] await entry.async_remove(self.hass) - del self._entries[entry.entry_id] self.async_update_issues() self._async_schedule_save() diff --git a/tests/components/matter/test_init.py b/tests/components/matter/test_init.py index f6576689413..553358f12e3 100644 --- a/tests/components/matter/test_init.py +++ b/tests/components/matter/test_init.py @@ -502,7 +502,7 @@ async def test_issue_registry_invalid_version( ("stop_addon_side_effect", "entry_state"), [ (None, ConfigEntryState.NOT_LOADED), - (SupervisorError("Boom"), ConfigEntryState.LOADED), + (SupervisorError("Boom"), ConfigEntryState.FAILED_UNLOAD), ], ) async def test_stop_addon( diff --git a/tests/components/unifi/test_hub.py b/tests/components/unifi/test_hub.py index 5492f6fe0df..8b129d3d648 100644 --- a/tests/components/unifi/test_hub.py +++ b/tests/components/unifi/test_hub.py @@ -76,7 +76,7 @@ async def test_reset_fails( return_value=False, ): assert not await hass.config_entries.async_unload(config_entry_setup.entry_id) - assert config_entry_setup.state is ConfigEntryState.LOADED + assert config_entry_setup.state is ConfigEntryState.FAILED_UNLOAD @pytest.mark.usefixtures("mock_device_registry") diff --git a/tests/components/zwave_js/test_init.py b/tests/components/zwave_js/test_init.py index 4f858f3e545..c575066b57c 100644 --- a/tests/components/zwave_js/test_init.py +++ b/tests/components/zwave_js/test_init.py @@ -847,7 +847,7 @@ async def test_issue_registry( ("stop_addon_side_effect", "entry_state"), [ (None, ConfigEntryState.NOT_LOADED), - (SupervisorError("Boom"), ConfigEntryState.LOADED), + (SupervisorError("Boom"), ConfigEntryState.FAILED_UNLOAD), ], ) async def test_stop_addon( diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index bf2280790fa..acc79deb538 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -468,8 +468,8 @@ async def test_remove_entry( hass: HomeAssistant, entry: config_entries.ConfigEntry ) -> None: """Mock removing an entry.""" - # Check that the entry is not yet removed from config entries - assert hass.config_entries.async_get_entry(entry.entry_id) + # Check that the entry is no longer in the config entries + assert not hass.config_entries.async_get_entry(entry.entry_id) remove_entry_calls.append(None) entity = MockEntity(unique_id="1234", name="Test Entity") @@ -2623,7 +2623,7 @@ async def test_entry_setup_invalid_state( ("unload_result", "expected_result", "expected_state", "has_runtime_data"), [ (True, True, config_entries.ConfigEntryState.NOT_LOADED, False), - (False, False, config_entries.ConfigEntryState.LOADED, True), + (False, False, config_entries.ConfigEntryState.FAILED_UNLOAD, True), ], ) async def test_entry_unload( @@ -2648,7 +2648,7 @@ async def test_entry_unload( """Mock unload entry.""" unload_entry_calls.append(None) verify_runtime_data() - assert entry.state is config_entries.ConfigEntryState.LOADED + assert entry.state is config_entries.ConfigEntryState.UNLOAD_IN_PROGRESS return unload_result entry = MockConfigEntry(domain="comp", state=config_entries.ConfigEntryState.LOADED)