Improve config entry state transitions when unloading and removing entries (#138522)

* Improve config entry state transitions when unloading and removing entries

* Update integrations which check for a single loaded entry

* Update tests checking state after unload fails

* Update homeassistant/config_entries.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2025-02-17 18:16:57 +01:00 committed by GitHub
parent ff16e587e8
commit e0795e6d07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 46 additions and 98 deletions

View File

@ -7,7 +7,7 @@ from dataclasses import dataclass
from adguardhome import AdGuardHome, AdGuardHomeConnectionError from adguardhome import AdGuardHome, AdGuardHomeConnectionError
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_NAME, CONF_NAME,
@ -123,12 +123,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> b
async def async_unload_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> bool:
"""Unload AdGuard Home config entry.""" """Unload AdGuard Home config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# This is the last loaded instance of AdGuard, deregister any services # This is the last loaded instance of AdGuard, deregister any services
hass.services.async_remove(DOMAIN, SERVICE_ADD_URL) hass.services.async_remove(DOMAIN, SERVICE_ADD_URL)
hass.services.async_remove(DOMAIN, SERVICE_REMOVE_URL) hass.services.async_remove(DOMAIN, SERVICE_REMOVE_URL)

View File

@ -10,7 +10,7 @@ from google.oauth2.credentials import Credentials
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import ( from homeassistant.core import (
HomeAssistant, HomeAssistant,
@ -99,12 +99,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
for service_name in hass.services.async_services_for_domain(DOMAIN): for service_name in hass.services.async_services_for_domain(DOMAIN):
hass.services.async_remove(DOMAIN, service_name) hass.services.async_remove(DOMAIN, service_name)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_NAME, Platform from homeassistant.const import CONF_NAME, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers import config_validation as cv, discovery
@ -59,12 +59,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: GoogleMailConfigEntry) -
async def async_unload_entry(hass: HomeAssistant, entry: GoogleMailConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: GoogleMailConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
for service_name in hass.services.async_services_for_domain(DOMAIN): for service_name in hass.services.async_services_for_domain(DOMAIN):
hass.services.async_remove(DOMAIN, service_name) hass.services.async_remove(DOMAIN, service_name)

View File

@ -12,7 +12,7 @@ from gspread.exceptions import APIError
from gspread.utils import ValueInputOption from gspread.utils import ValueInputOption
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import ( from homeassistant.exceptions import (
@ -81,12 +81,7 @@ async def async_unload_entry(
hass: HomeAssistant, entry: GoogleSheetsConfigEntry hass: HomeAssistant, entry: GoogleSheetsConfigEntry
) -> bool: ) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
for service_name in hass.services.async_services_for_domain(DOMAIN): for service_name in hass.services.async_services_for_domain(DOMAIN):
hass.services.async_remove(DOMAIN, service_name) hass.services.async_remove(DOMAIN, service_name)

View File

@ -11,7 +11,7 @@ from aioguardian import Client
from aioguardian.errors import GuardianError from aioguardian.errors import GuardianError
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
ATTR_DEVICE_ID, ATTR_DEVICE_ID,
CONF_DEVICE_ID, CONF_DEVICE_ID,
@ -247,12 +247,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if unload_ok: if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# If this is the last loaded instance of Guardian, deregister any services # If this is the last loaded instance of Guardian, deregister any services
# defined during integration setup: # defined during integration setup:
for service_name in SERVICES: for service_name in SERVICES:

View File

@ -19,7 +19,7 @@ from aiolookin import (
) )
from aiolookin.models import UDPCommandType, UDPEvent from aiolookin.models import UDPCommandType, UDPEvent
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, Platform from homeassistant.const import CONF_HOST, Platform
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
@ -192,12 +192,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
manager: LookinUDPManager = hass.data[DOMAIN][UDP_MANAGER] manager: LookinUDPManager = hass.data[DOMAIN][UDP_MANAGER]
await manager.async_stop() await manager.async_stop()
return unload_ok return unload_ok

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
from motionblinds import AsyncMotionMulticast from motionblinds import AsyncMotionMulticast
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_HOST, EVENT_HOMEASSISTANT_STOP from homeassistant.const import CONF_API_KEY, CONF_HOST, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
@ -124,12 +124,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
multicast.Unregister_motion_gateway(config_entry.data[CONF_HOST]) multicast.Unregister_motion_gateway(config_entry.data[CONF_HOST])
hass.data[DOMAIN].pop(config_entry.entry_id) hass.data[DOMAIN].pop(config_entry.entry_id)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# No motion gateways left, stop Motion multicast # No motion gateways left, stop Motion multicast
unsub_stop = hass.data[DOMAIN].pop(KEY_UNSUB_STOP) unsub_stop = hass.data[DOMAIN].pop(KEY_UNSUB_STOP)
unsub_stop() unsub_stop()

View File

@ -6,7 +6,6 @@ from aiohttp.cookiejar import CookieJar
import eternalegypt import eternalegypt
from eternalegypt.eternalegypt import SMS from eternalegypt.eternalegypt import SMS
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, Platform from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
@ -117,12 +116,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: NetgearLTEConfigEntry) -
async def async_unload_entry(hass: HomeAssistant, entry: NetgearLTEConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: NetgearLTEConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
hass.data.pop(DOMAIN, None) hass.data.pop(DOMAIN, None)
for service_name in hass.services.async_services()[DOMAIN]: for service_name in hass.services.async_services()[DOMAIN]:
hass.services.async_remove(DOMAIN, service_name) hass.services.async_remove(DOMAIN, service_name)

View File

@ -13,7 +13,7 @@ from regenmaschine.controller import Controller
from regenmaschine.errors import RainMachineError, UnknownAPICallError from regenmaschine.errors import RainMachineError, UnknownAPICallError
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE_ID, CONF_DEVICE_ID,
CONF_IP_ADDRESS, CONF_IP_ADDRESS,
@ -465,12 +465,7 @@ async def async_unload_entry(
) -> bool: ) -> bool:
"""Unload an RainMachine config entry.""" """Unload an RainMachine config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state is ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# If this is the last loaded instance of RainMachine, deregister any services # If this is the last loaded instance of RainMachine, deregister any services
# defined during integration setup: # defined during integration setup:
for service_name in ( for service_name in (

View File

@ -39,7 +39,7 @@ from simplipy.websocket import (
) )
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
ATTR_CODE, ATTR_CODE,
ATTR_DEVICE_ID, ATTR_DEVICE_ID,
@ -402,12 +402,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if unload_ok: if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# If this is the last loaded instance of SimpliSafe, deregister any services # If this is the last loaded instance of SimpliSafe, deregister any services
# defined during integration setup: # defined during integration setup:
for service_name in SERVICES: for service_name in SERVICES:

View File

@ -11,7 +11,7 @@ from tplink_omada_client.exceptions import (
UnsupportedControllerVersion, UnsupportedControllerVersion,
) )
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
@ -80,12 +80,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: OmadaConfigEntry) -> boo
async def async_unload_entry(hass: HomeAssistant, entry: OmadaConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: OmadaConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# This is the last loaded instance of Omada, deregister any services # This is the last loaded instance of Omada, deregister any services
hass.services.async_remove(DOMAIN, "reconnect_client") hass.services.async_remove(DOMAIN, "reconnect_client")

View File

@ -7,7 +7,7 @@ import voluptuous as vol
from xiaomi_gateway import AsyncXiaomiGatewayMulticast, XiaomiGateway from xiaomi_gateway import AsyncXiaomiGatewayMulticast, XiaomiGateway
from homeassistant.components import persistent_notification from homeassistant.components import persistent_notification
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
ATTR_DEVICE_ID, ATTR_DEVICE_ID,
CONF_HOST, CONF_HOST,
@ -216,12 +216,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
if unload_ok: if unload_ok:
hass.data[DOMAIN][GATEWAYS_KEY].pop(config_entry.entry_id) hass.data[DOMAIN][GATEWAYS_KEY].pop(config_entry.entry_id)
loaded_entries = [ if not hass.config_entries.async_loaded_entries(DOMAIN):
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# No gateways left, stop Xiaomi socket # No gateways left, stop Xiaomi socket
unsub_stop = hass.data[DOMAIN].pop(KEY_UNSUB_STOP) unsub_stop = hass.data[DOMAIN].pop(KEY_UNSUB_STOP)
unsub_stop() unsub_stop()

View File

@ -155,6 +155,8 @@ class ConfigEntryState(Enum):
"""An error occurred when trying to unload the entry""" """An error occurred when trying to unload the entry"""
SETUP_IN_PROGRESS = "setup_in_progress", False SETUP_IN_PROGRESS = "setup_in_progress", False
"""The config entry is setting up.""" """The config entry is setting up."""
UNLOAD_IN_PROGRESS = "unload_in_progress", False
"""The config entry is being unloaded."""
_recoverable: bool _recoverable: bool
@ -955,18 +957,25 @@ class ConfigEntry[_DataT = Any]:
) )
return False return False
if domain_is_integration:
self._async_set_state(hass, ConfigEntryState.UNLOAD_IN_PROGRESS, None)
try: try:
result = await component.async_unload_entry(hass, self) result = await component.async_unload_entry(hass, self)
assert isinstance(result, bool) assert isinstance(result, bool)
# Only adjust state if we unloaded the component # Only do side effects if we unloaded the integration
if domain_is_integration and result: if domain_is_integration:
await self._async_process_on_unload(hass) if result:
if hasattr(self, "runtime_data"): await self._async_process_on_unload(hass)
object.__delattr__(self, "runtime_data") if hasattr(self, "runtime_data"):
object.__delattr__(self, "runtime_data")
self._async_set_state(hass, ConfigEntryState.NOT_LOADED, None) self._async_set_state(hass, ConfigEntryState.NOT_LOADED, None)
else:
self._async_set_state(
hass, ConfigEntryState.FAILED_UNLOAD, "Unload failed"
)
except Exception as exc: except Exception as exc:
_LOGGER.exception( _LOGGER.exception(
@ -2052,9 +2061,9 @@ class ConfigEntries:
else: else:
unload_success = await self.async_unload(entry_id, _lock=False) unload_success = await self.async_unload(entry_id, _lock=False)
del self._entries[entry.entry_id]
await entry.async_remove(self.hass) await entry.async_remove(self.hass)
del self._entries[entry.entry_id]
self.async_update_issues() self.async_update_issues()
self._async_schedule_save() self._async_schedule_save()

View File

@ -502,7 +502,7 @@ async def test_issue_registry_invalid_version(
("stop_addon_side_effect", "entry_state"), ("stop_addon_side_effect", "entry_state"),
[ [
(None, ConfigEntryState.NOT_LOADED), (None, ConfigEntryState.NOT_LOADED),
(SupervisorError("Boom"), ConfigEntryState.LOADED), (SupervisorError("Boom"), ConfigEntryState.FAILED_UNLOAD),
], ],
) )
async def test_stop_addon( async def test_stop_addon(

View File

@ -76,7 +76,7 @@ async def test_reset_fails(
return_value=False, return_value=False,
): ):
assert not await hass.config_entries.async_unload(config_entry_setup.entry_id) assert not await hass.config_entries.async_unload(config_entry_setup.entry_id)
assert config_entry_setup.state is ConfigEntryState.LOADED assert config_entry_setup.state is ConfigEntryState.FAILED_UNLOAD
@pytest.mark.usefixtures("mock_device_registry") @pytest.mark.usefixtures("mock_device_registry")

View File

@ -847,7 +847,7 @@ async def test_issue_registry(
("stop_addon_side_effect", "entry_state"), ("stop_addon_side_effect", "entry_state"),
[ [
(None, ConfigEntryState.NOT_LOADED), (None, ConfigEntryState.NOT_LOADED),
(SupervisorError("Boom"), ConfigEntryState.LOADED), (SupervisorError("Boom"), ConfigEntryState.FAILED_UNLOAD),
], ],
) )
async def test_stop_addon( async def test_stop_addon(

View File

@ -468,8 +468,8 @@ async def test_remove_entry(
hass: HomeAssistant, entry: config_entries.ConfigEntry hass: HomeAssistant, entry: config_entries.ConfigEntry
) -> None: ) -> None:
"""Mock removing an entry.""" """Mock removing an entry."""
# Check that the entry is not yet removed from config entries # Check that the entry is no longer in the config entries
assert hass.config_entries.async_get_entry(entry.entry_id) assert not hass.config_entries.async_get_entry(entry.entry_id)
remove_entry_calls.append(None) remove_entry_calls.append(None)
entity = MockEntity(unique_id="1234", name="Test Entity") entity = MockEntity(unique_id="1234", name="Test Entity")
@ -2623,7 +2623,7 @@ async def test_entry_setup_invalid_state(
("unload_result", "expected_result", "expected_state", "has_runtime_data"), ("unload_result", "expected_result", "expected_state", "has_runtime_data"),
[ [
(True, True, config_entries.ConfigEntryState.NOT_LOADED, False), (True, True, config_entries.ConfigEntryState.NOT_LOADED, False),
(False, False, config_entries.ConfigEntryState.LOADED, True), (False, False, config_entries.ConfigEntryState.FAILED_UNLOAD, True),
], ],
) )
async def test_entry_unload( async def test_entry_unload(
@ -2648,7 +2648,7 @@ async def test_entry_unload(
"""Mock unload entry.""" """Mock unload entry."""
unload_entry_calls.append(None) unload_entry_calls.append(None)
verify_runtime_data() verify_runtime_data()
assert entry.state is config_entries.ConfigEntryState.LOADED assert entry.state is config_entries.ConfigEntryState.UNLOAD_IN_PROGRESS
return unload_result return unload_result
entry = MockConfigEntry(domain="comp", state=config_entries.ConfigEntryState.LOADED) entry = MockConfigEntry(domain="comp", state=config_entries.ConfigEntryState.LOADED)