diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index d689d4548a9..34afc77e528 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -12,7 +12,7 @@ import weakref import attr from homeassistant import data_entry_flow, loader -from homeassistant.const import EVENT_HOMEASSISTANT_STARTED +from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback from homeassistant.exceptions import ( ConfigEntryAuthFailed, @@ -331,6 +331,17 @@ class ConfigEntry: else: self.state = ENTRY_STATE_SETUP_ERROR + async def async_shutdown(self) -> None: + """Call when Home Assistant is stopping.""" + self.async_cancel_retry_setup() + + @callback + def async_cancel_retry_setup(self) -> None: + """Cancel retry setup.""" + if self._async_cancel_retry_setup is not None: + self._async_cancel_retry_setup() + self._async_cancel_retry_setup = None + async def async_unload( self, hass: HomeAssistant, *, integration: loader.Integration | None = None ) -> bool: @@ -360,9 +371,7 @@ class ConfigEntry: return False if self.state != ENTRY_STATE_LOADED: - if self._async_cancel_retry_setup is not None: - self._async_cancel_retry_setup() - self._async_cancel_retry_setup = None + self.async_cancel_retry_setup() self.state = ENTRY_STATE_NOT_LOADED return True @@ -778,6 +787,12 @@ class ConfigEntries: return {"require_restart": not unload_success} + async def _async_shutdown(self, event: Event) -> None: + """Call when Home Assistant is stopping.""" + await asyncio.gather( + *[entry.async_shutdown() for entry in self._entries.values()] + ) + async def async_initialize(self) -> None: """Initialize config entry config.""" # Migrating for config entries stored before 0.73 @@ -787,6 +802,8 @@ class ConfigEntries: old_conf_migrate_func=_old_conf_migrator, ) + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown) + if config is None: self._entries = {} return diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 17131665240..46279fcb140 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -12,8 +12,12 @@ import voluptuous as vol from homeassistant import config as conf_util from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL -from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.const import ( + CONF_ENTITY_NAMESPACE, + CONF_SCAN_INTERVAL, + EVENT_HOMEASSISTANT_STOP, +) +from homeassistant.core import Event, HomeAssistant, ServiceCall, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( config_per_platform, @@ -118,6 +122,8 @@ class EntityComponent: This method must be run in the event loop. """ + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown) + self.config = config # Look in config for Domain, Domain 2, Domain 3 etc and load them @@ -322,3 +328,9 @@ class EntityComponent: scan_interval=scan_interval, entity_namespace=entity_namespace, ) + + async def _async_shutdown(self, event: Event) -> None: + """Call when Home Assistant is stopping.""" + await asyncio.gather( + *[platform.async_shutdown() for platform in chain(self._platforms.values())] + ) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 25996c81d9d..ef45b8dcd97 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -174,6 +174,18 @@ class EntityPlatform: await self._async_setup_platform(async_create_setup_task) + async def async_shutdown(self) -> None: + """Call when Home Assistant is stopping.""" + self.async_cancel_retry_setup() + self.async_unsub_polling() + + @callback + def async_cancel_retry_setup(self) -> None: + """Cancel retry setup.""" + if self._async_cancel_retry_setup is not None: + self._async_cancel_retry_setup() + self._async_cancel_retry_setup = None + async def async_setup_entry(self, config_entry: config_entries.ConfigEntry) -> bool: """Set up the platform from a config entry.""" # Store it so that we can save config entry ID in entity registry @@ -549,9 +561,7 @@ class EntityPlatform: This method must be run in the event loop. """ - if self._async_cancel_retry_setup is not None: - self._async_cancel_retry_setup() - self._async_cancel_retry_setup = None + self.async_cancel_retry_setup() if not self.entities: return @@ -560,10 +570,15 @@ class EntityPlatform: await asyncio.gather(*tasks) + self.async_unsub_polling() + self._setup_complete = False + + @callback + def async_unsub_polling(self) -> None: + """Stop polling.""" if self._async_unsub_polling is not None: self._async_unsub_polling() self._async_unsub_polling = None - self._setup_complete = False async def async_destroy(self) -> None: """Destroy an entity platform. diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 8d61ec7d509..1d18111b0d3 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -8,7 +8,11 @@ from unittest.mock import AsyncMock, Mock, patch import pytest import voluptuous as vol -from homeassistant.const import ENTITY_MATCH_ALL, ENTITY_MATCH_NONE +from homeassistant.const import ( + ENTITY_MATCH_ALL, + ENTITY_MATCH_NONE, + EVENT_HOMEASSISTANT_STOP, +) import homeassistant.core as ha from homeassistant.exceptions import PlatformNotReady from homeassistant.helpers import discovery @@ -487,3 +491,25 @@ async def test_register_entity_service(hass): DOMAIN, "hello", {"area_id": ENTITY_MATCH_NONE, "some": "data"}, blocking=True ) assert len(calls) == 2 + + +async def test_platforms_shutdown_on_stop(hass): + """Test that we shutdown platforms on stop.""" + platform1_setup = Mock(side_effect=[PlatformNotReady, PlatformNotReady, None]) + mock_integration(hass, MockModule("mod1")) + mock_entity_platform(hass, "test_domain.mod1", MockPlatform(platform1_setup)) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + + await component.async_setup({DOMAIN: {"platform": "mod1"}}) + await hass.async_block_till_done() + assert len(platform1_setup.mock_calls) == 1 + assert "test_domain.mod1" not in hass.config.components + + with patch.object( + component._platforms[DOMAIN], "async_shutdown" + ) as mock_async_shutdown: + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + + assert mock_async_shutdown.called diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index e842d5aa1ae..d24084ff517 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -685,6 +685,29 @@ async def test_reset_cancels_retry_setup_when_not_started(hass): assert ent_platform._async_cancel_retry_setup is None +async def test_stop_shutdown_cancels_retry_setup_and_interval_listener(hass): + """Test that shutdown will cancel scheduled a setup retry and interval listener.""" + async_setup_entry = Mock(side_effect=PlatformNotReady) + platform = MockPlatform(async_setup_entry=async_setup_entry) + config_entry = MockConfigEntry() + ent_platform = MockEntityPlatform( + hass, platform_name=config_entry.domain, platform=platform + ) + + with patch.object(entity_platform, "async_call_later") as mock_call_later: + assert not await ent_platform.async_setup_entry(config_entry) + + assert len(mock_call_later.mock_calls) == 1 + assert len(mock_call_later.return_value.mock_calls) == 0 + assert ent_platform._async_cancel_retry_setup is not None + + await ent_platform.async_shutdown() + + assert len(mock_call_later.return_value.mock_calls) == 1 + assert ent_platform._async_unsub_polling is None + assert ent_platform._async_cancel_retry_setup is None + + async def test_not_fails_with_adding_empty_entities_(hass): """Test for not fails on empty entities list.""" component = EntityComponent(_LOGGER, DOMAIN, hass) diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index dbfe48129c1..20ab5e67fef 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest from homeassistant import config_entries, data_entry_flow, loader -from homeassistant.const import EVENT_HOMEASSISTANT_STARTED +from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import CoreState, callback from homeassistant.exceptions import ( ConfigEntryAuthFailed, @@ -1405,7 +1405,7 @@ async def test_reload_entry_entity_registry_works(hass): assert len(mock_unload_entry.mock_calls) == 1 -async def test_unqiue_id_persisted(hass, manager): +async def test_unique_id_persisted(hass, manager): """Test that a unique ID is stored in the config entry.""" mock_setup_entry = AsyncMock(return_value=True) @@ -2667,3 +2667,40 @@ async def test_setup_raise_auth_failed_from_future_coordinator_update(hass, capl assert entry.state == config_entries.ENTRY_STATE_LOADED flows = hass.config_entries.flow.async_progress() assert len(flows) == 1 + + +async def test_initialize_and_shutdown(hass): + """Test we call the shutdown function at stop.""" + manager = config_entries.ConfigEntries(hass, {}) + + with patch.object(manager, "_async_shutdown") as mock_async_shutdown: + await manager.async_initialize() + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + + assert mock_async_shutdown.called + + +async def test_setup_retrying_during_shutdown(hass): + """Test if we shutdown an entry that is in retry mode.""" + entry = MockConfigEntry(domain="test") + + mock_setup_entry = AsyncMock(side_effect=ConfigEntryNotReady) + mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry)) + mock_entity_platform(hass, "config_flow.test", None) + + with patch("homeassistant.helpers.event.async_call_later") as mock_call: + await entry.async_setup(hass) + + assert entry.state == config_entries.ENTRY_STATE_SETUP_RETRY + assert len(mock_call.return_value.mock_calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + + assert len(mock_call.return_value.mock_calls) == 0 + + async_fire_time_changed(hass, dt.utcnow() + timedelta(hours=4)) + await hass.async_block_till_done() + + assert len(mock_call.return_value.mock_calls) == 0