diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index bd985517ca7..902fa0d03f2 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -12,6 +12,8 @@ from types import MappingProxyType, MethodType from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import weakref +import async_timeout + from . import data_entry_flow, loader from .backports.enum import StrEnum from .components import persistent_notification @@ -19,7 +21,7 @@ from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platfo from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError from .helpers import device_registry, entity_registry, storage -from .helpers.dispatcher import async_dispatcher_send +from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send from .helpers.event import async_call_later from .helpers.frame import report from .helpers.typing import UNDEFINED, ConfigType, DiscoveryInfoType, UndefinedType @@ -1239,6 +1241,36 @@ class ConfigEntries: await entry.async_setup(self.hass, integration=integration) return True + async def async_wait_for_states( + self, entry: ConfigEntry, states: set[ConfigEntryState], timeout: float = 60.0 + ) -> ConfigEntryState: + """Wait for the setup of an entry to reach one of the supplied states state. + + Returns the state the entry reached or raises asyncio.TimeoutError if the + entry did not reach one of the supplied states within the timeout. + """ + state_reached_future: asyncio.Future[ConfigEntryState] = asyncio.Future() + + @callback + def _async_entry_changed( + change: ConfigEntryChange, event_entry: ConfigEntry + ) -> None: + if ( + event_entry is entry + and change is ConfigEntryChange.UPDATED + and entry.state in states + ): + state_reached_future.set_result(entry.state) + + unsub = async_dispatcher_connect( + self.hass, SIGNAL_CONFIG_ENTRY_CHANGED, _async_entry_changed + ) + try: + async with async_timeout.timeout(timeout): + return await state_reached_future + finally: + unsub() + async def async_unload_platforms( self, entry: ConfigEntry, platforms: Iterable[Platform | str] ) -> bool: diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 99e26be6d75..7684e9ff260 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3330,3 +3330,110 @@ async def test_reauth(hass): entry2.async_start_reauth(hass, {"extra_context": "some_extra_context"}) await hass.async_block_till_done() assert len(hass.config_entries.flow.async_progress()) == 2 + + +async def test_wait_for_loading_entry(hass): + """Test waiting for entry to be set up.""" + + entry = MockConfigEntry(title="test_title", domain="test") + + mock_setup_entry = AsyncMock(return_value=True) + mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry)) + mock_entity_platform(hass, "config_flow.test", None) + + await entry.async_setup(hass) + await hass.async_block_till_done() + + flow = hass.config_entries.flow + + async def _load_entry(): + # Mock config entry + assert await async_setup_component(hass, "test", {}) + + entry = MockConfigEntry(title="test_title", domain="test") + entry.add_to_hass(hass) + flow = hass.config_entries.flow + with patch.object(flow, "async_init", wraps=flow.async_init): + hass.async_add_job(_load_entry) + new_state = await hass.config_entries.async_wait_for_states( + entry, + { + config_entries.ConfigEntryState.LOADED, + config_entries.ConfigEntryState.SETUP_ERROR, + }, + timeout=1.0, + ) + assert new_state is config_entries.ConfigEntryState.LOADED + assert entry.state is config_entries.ConfigEntryState.LOADED + + +async def test_wait_for_loading_failed_entry(hass): + """Test waiting for entry to be set up that fails loading.""" + + entry = MockConfigEntry(title="test_title", domain="test") + + mock_setup_entry = AsyncMock(side_effect=HomeAssistantError) + mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry)) + mock_entity_platform(hass, "config_flow.test", None) + + await entry.async_setup(hass) + await hass.async_block_till_done() + + flow = hass.config_entries.flow + + async def _load_entry(): + # Mock config entry + assert await async_setup_component(hass, "test", {}) + + entry = MockConfigEntry(title="test_title", domain="test") + entry.add_to_hass(hass) + flow = hass.config_entries.flow + with patch.object(flow, "async_init", wraps=flow.async_init): + hass.async_add_job(_load_entry) + new_state = await hass.config_entries.async_wait_for_states( + entry, + { + config_entries.ConfigEntryState.LOADED, + config_entries.ConfigEntryState.SETUP_ERROR, + }, + timeout=1.0, + ) + assert new_state is config_entries.ConfigEntryState.SETUP_ERROR + assert entry.state is config_entries.ConfigEntryState.SETUP_ERROR + + +async def test_wait_for_loading_timeout(hass): + """Test waiting for entry to be set up that fails with a timeout.""" + + async def _async_setup_entry(hass, entry): + await asyncio.sleep(1) + return True + + entry = MockConfigEntry(title="test_title", domain="test") + + mock_integration(hass, MockModule("test", async_setup_entry=_async_setup_entry)) + mock_entity_platform(hass, "config_flow.test", None) + + await entry.async_setup(hass) + await hass.async_block_till_done() + + flow = hass.config_entries.flow + + async def _load_entry(): + # Mock config entry + assert await async_setup_component(hass, "test", {}) + + entry = MockConfigEntry(title="test_title", domain="test") + entry.add_to_hass(hass) + flow = hass.config_entries.flow + with patch.object(flow, "async_init", wraps=flow.async_init): + hass.async_add_job(_load_entry) + with pytest.raises(asyncio.exceptions.TimeoutError): + await hass.config_entries.async_wait_for_states( + entry, + { + config_entries.ConfigEntryState.LOADED, + config_entries.ConfigEntryState.SETUP_ERROR, + }, + timeout=0.1, + )