From 1a274adc28a9e69fe06d3ebaf1a6e168e844ecda Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 17 Nov 2022 21:52:57 +0100 Subject: [PATCH] Add config_entries.async_wait_component (#76980) Co-authored-by: thecode --- .../components/device_automation/trigger.py | 24 ++++++- .../components/webostv/device_trigger.py | 7 -- homeassistant/components/webostv/helpers.py | 14 ---- homeassistant/components/zha/__init__.py | 5 +- homeassistant/components/zha/core/const.py | 1 - .../components/zha/device_trigger.py | 24 ++++--- homeassistant/config_entries.py | 18 ++++- homeassistant/setup.py | 24 ++++++- .../components/shelly/test_device_trigger.py | 26 ++++++-- .../components/webostv/test_device_trigger.py | 39 ----------- tests/test_config_entries.py | 66 ++++++++++++++++++- 11 files changed, 157 insertions(+), 91 deletions(-) diff --git a/homeassistant/components/device_automation/trigger.py b/homeassistant/components/device_automation/trigger.py index bd72b24d844..05f2f79ff28 100644 --- a/homeassistant/components/device_automation/trigger.py +++ b/homeassistant/components/device_automation/trigger.py @@ -5,8 +5,9 @@ from typing import Any, Protocol, cast import voluptuous as vol -from homeassistant.const import CONF_DOMAIN +from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN from homeassistant.core import CALLBACK_TYPE, HomeAssistant +from homeassistant.helpers import device_registry as dr from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType @@ -63,6 +64,27 @@ async def async_validate_trigger_config( ) if not hasattr(platform, "async_validate_trigger_config"): return cast(ConfigType, platform.TRIGGER_SCHEMA(config)) + + # Only call the dynamic validator if the relevant config entry is loaded + registry = dr.async_get(hass) + if not (device := registry.async_get(config[CONF_DEVICE_ID])): + raise InvalidDeviceAutomationConfig + + device_config_entry = None + for entry_id in device.config_entries: + if not (entry := hass.config_entries.async_get_entry(entry_id)): + continue + if entry.domain != config[CONF_DOMAIN]: + continue + device_config_entry = entry + break + + if not device_config_entry: + raise InvalidDeviceAutomationConfig + + if not await hass.config_entries.async_wait_component(device_config_entry): + return config + return await platform.async_validate_trigger_config(hass, config) except InvalidDeviceAutomationConfig as err: raise vol.Invalid(str(err) or "Invalid trigger configuration") from err diff --git a/homeassistant/components/webostv/device_trigger.py b/homeassistant/components/webostv/device_trigger.py index ef3e74a7daa..590cbc19de8 100644 --- a/homeassistant/components/webostv/device_trigger.py +++ b/homeassistant/components/webostv/device_trigger.py @@ -18,7 +18,6 @@ from .const import DOMAIN from .helpers import ( async_get_client_wrapper_by_device_entry, async_get_device_entry_by_device_id, - async_is_device_config_entry_not_loaded, ) from .triggers.turn_on import PLATFORM_TYPE as TURN_ON_PLATFORM_TYPE @@ -36,12 +35,6 @@ async def async_validate_trigger_config( """Validate config.""" config = TRIGGER_SCHEMA(config) - try: - if async_is_device_config_entry_not_loaded(hass, config[CONF_DEVICE_ID]): - return config - except ValueError as err: - raise InvalidDeviceAutomationConfig(err) from err - if config[CONF_TYPE] == TURN_ON_PLATFORM_TYPE: device_id = config[CONF_DEVICE_ID] try: diff --git a/homeassistant/components/webostv/helpers.py b/homeassistant/components/webostv/helpers.py index 0ee3805f42f..4f1ab9dfebe 100644 --- a/homeassistant/components/webostv/helpers.py +++ b/homeassistant/components/webostv/helpers.py @@ -1,7 +1,6 @@ """Helper functions for webOS Smart TV.""" from __future__ import annotations -from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.device_registry import DeviceEntry @@ -26,19 +25,6 @@ def async_get_device_entry_by_device_id( return device -@callback -def async_is_device_config_entry_not_loaded( - hass: HomeAssistant, device_id: str -) -> bool: - """Return whether device's config entries are not loaded.""" - device = async_get_device_entry_by_device_id(hass, device_id) - return any( - (entry := hass.config_entries.async_get_entry(entry_id)) - and entry.state != ConfigEntryState.LOADED - for entry_id in device.config_entries - ) - - @callback def async_get_device_id_from_entity_id(hass: HomeAssistant, entity_id: str) -> str: """ diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index 0a7d43120f7..70b9dfd9b46 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -35,7 +35,6 @@ from .core.const import ( DOMAIN, PLATFORMS, SIGNAL_ADD_ENTITIES, - ZHA_DEVICES_LOADED_EVENT, RadioType, ) from .core.discovery import GROUP_PROBE @@ -76,7 +75,7 @@ _LOGGER = logging.getLogger(__name__) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up ZHA from config.""" - hass.data[DATA_ZHA] = {ZHA_DEVICES_LOADED_EVENT: asyncio.Event()} + hass.data[DATA_ZHA] = {} if DOMAIN in config: conf = config[DOMAIN] @@ -110,7 +109,6 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b zha_gateway = ZHAGateway(hass, config, config_entry) await zha_gateway.async_initialize() - hass.data[DATA_ZHA][ZHA_DEVICES_LOADED_EVENT].set() device_registry = dr.async_get(hass) device_registry.async_get_or_create( @@ -143,7 +141,6 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> """Unload ZHA config entry.""" zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] await zha_gateway.shutdown() - hass.data[DATA_ZHA][ZHA_DEVICES_LOADED_EVENT].clear() GROUP_PROBE.cleanup() api.async_unload_api(hass) diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index b9871a1f2ab..180be2d8830 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -395,7 +395,6 @@ ZHA_GW_MSG_GROUP_REMOVED = "group_removed" ZHA_GW_MSG_LOG_ENTRY = "log_entry" ZHA_GW_MSG_LOG_OUTPUT = "log_output" ZHA_GW_MSG_RAW_INIT = "raw_device_initialized" -ZHA_DEVICES_LOADED_EVENT = "zha_devices_loaded_event" class Strobe(t.enum8): diff --git a/homeassistant/components/zha/device_trigger.py b/homeassistant/components/zha/device_trigger.py index 94b94b89e40..6f78aa6f858 100644 --- a/homeassistant/components/zha/device_trigger.py +++ b/homeassistant/components/zha/device_trigger.py @@ -14,7 +14,7 @@ from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType from . import DOMAIN as ZHA_DOMAIN -from .core.const import DATA_ZHA, ZHA_DEVICES_LOADED_EVENT, ZHA_EVENT +from .core.const import ZHA_EVENT from .core.helpers import async_get_zha_device CONF_SUBTYPE = "subtype" @@ -32,18 +32,16 @@ async def async_validate_trigger_config( """Validate config.""" config = TRIGGER_SCHEMA(config) - if ZHA_DOMAIN in hass.config.components: - await hass.data[DATA_ZHA][ZHA_DEVICES_LOADED_EVENT].wait() - trigger = (config[CONF_TYPE], config[CONF_SUBTYPE]) - try: - zha_device = async_get_zha_device(hass, config[CONF_DEVICE_ID]) - except (KeyError, AttributeError, IntegrationError) as err: - raise InvalidDeviceAutomationConfig from err - if ( - zha_device.device_automation_triggers is None - or trigger not in zha_device.device_automation_triggers - ): - raise InvalidDeviceAutomationConfig + trigger = (config[CONF_TYPE], config[CONF_SUBTYPE]) + try: + zha_device = async_get_zha_device(hass, config[CONF_DEVICE_ID]) + except (KeyError, AttributeError, IntegrationError) as err: + raise InvalidDeviceAutomationConfig from err + if ( + zha_device.device_automation_triggers is None + or trigger not in zha_device.device_automation_triggers + ): + raise InvalidDeviceAutomationConfig return config diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index ddef5d7f226..eb6fcf855a3 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -26,7 +26,7 @@ from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send from .helpers.event import async_call_later from .helpers.frame import report from .helpers.typing import UNDEFINED, ConfigType, DiscoveryInfoType, UndefinedType -from .setup import async_process_deps_reqs, async_setup_component +from .setup import DATA_SETUP_DONE, async_process_deps_reqs, async_setup_component from .util import uuid as uuid_util from .util.decorator import Registry @@ -1314,6 +1314,22 @@ class ConfigEntries: """Return data to save.""" return {"entries": [entry.as_dict() for entry in self._entries.values()]} + async def async_wait_component(self, entry: ConfigEntry) -> bool: + """Wait for an entry's component to load and return if the entry is loaded. + + This is primarily intended for existing config entries which are loaded at + startup, awaiting this function will block until the component and all its + config entries are loaded. + Config entries which are created after Home Assistant is started can't be waited + for, the function will just return if the config entry is loaded or not. + """ + if setup_event := self.hass.data.get(DATA_SETUP_DONE, {}).get(entry.domain): + await setup_event.wait() + # The component was not loaded. + if entry.domain not in self.hass.config.components: + return False + return entry.state == ConfigEntryState.LOADED + async def _old_conf_migrator(old_config: dict[str, Any]) -> dict[str, Any]: """Migrate the pre-0.73 config format to the latest version.""" diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 2be22910a08..1172ab00bd7 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -29,11 +29,27 @@ ATTR_COMPONENT = "component" BASE_PLATFORMS = {platform.value for platform in Platform} +# DATA_SETUP is a dict[str, asyncio.Task[bool]], indicating domains which are currently +# being setup or which failed to setup +# - Tasks are added to DATA_SETUP by `async_setup_component`, the key is the domain being setup +# and the Task is the `_async_setup_component` helper. +# - Tasks are removed from DATA_SETUP if setup was successful, that is, the task returned True +DATA_SETUP = "setup_tasks" + +# DATA_SETUP_DONE is a dict [str, asyncio.Event], indicating components which will be setup +# - Events are added to DATA_SETUP_DONE during bootstrap by async_set_domains_to_be_loaded, +# the key is the domain which will be loaded +# - Events are set and removed from DATA_SETUP_DONE when async_setup_component is finished, +# regardless of if the setup was successful or not. DATA_SETUP_DONE = "setup_done" + +# DATA_SETUP_DONE is a dict [str, datetime], indicating when an attempt to setup a component +# started DATA_SETUP_STARTED = "setup_started" + +# DATA_SETUP_TIME is a dict [str, timedelta], indicating how time was spent setting up a component DATA_SETUP_TIME = "setup_time" -DATA_SETUP = "setup_tasks" DATA_DEPS_REQS = "deps_reqs_processed" SLOW_SETUP_WARNING = 10 @@ -44,7 +60,9 @@ SLOW_SETUP_MAX_WAIT = 300 def async_set_domains_to_be_loaded(hass: core.HomeAssistant, domains: set[str]) -> None: """Set domains that are going to be loaded from the config. - This will allow us to properly handle after_dependencies. + This allow us to: + - Properly handle after_dependencies. + - Keep track of domains which will load but have not yet finished loading """ hass.data[DATA_SETUP_DONE] = {domain: asyncio.Event() for domain in domains} @@ -265,7 +283,7 @@ async def _async_setup_component( await asyncio.sleep(0) await hass.config_entries.flow.async_wait_init_flow_finish(domain) - # Add to components before the async_setup + # Add to components before the entry.async_setup # call to avoid a deadlock when forwarding platforms hass.config.components.add(domain) diff --git a/tests/components/shelly/test_device_trigger.py b/tests/components/shelly/test_device_trigger.py index d5881696bf6..ec74745da15 100644 --- a/tests/components/shelly/test_device_trigger.py +++ b/tests/components/shelly/test_device_trigger.py @@ -240,9 +240,14 @@ async def test_if_fires_on_click_event_rpc_device(hass, calls, mock_rpc_device): assert calls[0].data["some"] == "test_trigger_single_push" -async def test_validate_trigger_block_device_not_ready(hass, calls, mock_block_device): +async def test_validate_trigger_block_device_not_ready( + hass, calls, mock_block_device, monkeypatch +): """Test validate trigger config when block device is not ready.""" - await init_integration(hass, 1) + monkeypatch.setattr(mock_block_device, "initialized", False) + entry = await init_integration(hass, 1) + dev_reg = async_get_dev_reg(hass) + device = async_entries_for_config_entry(dev_reg, entry.entry_id)[0] assert await async_setup_component( hass, @@ -253,7 +258,7 @@ async def test_validate_trigger_block_device_not_ready(hass, calls, mock_block_d "trigger": { CONF_PLATFORM: "device", CONF_DOMAIN: DOMAIN, - CONF_DEVICE_ID: "device_not_ready", + CONF_DEVICE_ID: device.id, CONF_TYPE: "single", CONF_SUBTYPE: "button1", }, @@ -266,7 +271,7 @@ async def test_validate_trigger_block_device_not_ready(hass, calls, mock_block_d }, ) message = { - CONF_DEVICE_ID: "device_not_ready", + CONF_DEVICE_ID: device.id, ATTR_CLICK_TYPE: "single", ATTR_CHANNEL: 1, } @@ -277,8 +282,15 @@ async def test_validate_trigger_block_device_not_ready(hass, calls, mock_block_d assert calls[0].data["some"] == "test_trigger_single_click" -async def test_validate_trigger_rpc_device_not_ready(hass, calls, mock_rpc_device): +async def test_validate_trigger_rpc_device_not_ready( + hass, calls, mock_rpc_device, monkeypatch +): """Test validate trigger config when RPC device is not ready.""" + monkeypatch.setattr(mock_rpc_device, "initialized", False) + entry = await init_integration(hass, 2) + dev_reg = async_get_dev_reg(hass) + device = async_entries_for_config_entry(dev_reg, entry.entry_id)[0] + assert await async_setup_component( hass, automation.DOMAIN, @@ -288,7 +300,7 @@ async def test_validate_trigger_rpc_device_not_ready(hass, calls, mock_rpc_devic "trigger": { CONF_PLATFORM: "device", CONF_DOMAIN: DOMAIN, - CONF_DEVICE_ID: "device_not_ready", + CONF_DEVICE_ID: device.id, CONF_TYPE: "single_push", CONF_SUBTYPE: "button1", }, @@ -301,7 +313,7 @@ async def test_validate_trigger_rpc_device_not_ready(hass, calls, mock_rpc_devic }, ) message = { - CONF_DEVICE_ID: "device_not_ready", + CONF_DEVICE_ID: device.id, ATTR_CLICK_TYPE: "single_push", ATTR_CHANNEL: 1, } diff --git a/tests/components/webostv/test_device_trigger.py b/tests/components/webostv/test_device_trigger.py index db15ce3a592..befa62340a5 100644 --- a/tests/components/webostv/test_device_trigger.py +++ b/tests/components/webostv/test_device_trigger.py @@ -98,41 +98,6 @@ async def test_if_fires_on_turn_on_request(hass, calls, client): assert calls[1].data["id"] == 0 -async def test_get_triggers_for_invalid_device_id(hass, caplog): - """Test error raised for invalid shelly device_id.""" - await async_setup_component(hass, "persistent_notification", {}) - - assert await async_setup_component( - hass, - automation.DOMAIN, - { - automation.DOMAIN: [ - { - "trigger": { - "platform": "device", - "domain": DOMAIN, - "device_id": "invalid_device_id", - "type": "webostv.turn_on", - }, - "action": { - "service": "test.automation", - "data_template": { - "some": "{{ trigger.invalid_device }}", - "id": "{{ trigger.id }}", - }, - }, - } - ] - }, - ) - await hass.async_block_till_done() - - assert ( - "Invalid config for [automation]: Device invalid_device_id is not a valid webostv device" - in caplog.text - ) - - async def test_failure_scenarios(hass, client): """Test failure scenarios.""" await setup_webostv(hass) @@ -173,7 +138,3 @@ async def test_failure_scenarios(hass, client): # Test that device id from non webostv domain raises exception with pytest.raises(InvalidDeviceAutomationConfig): await device_trigger.async_validate_trigger_config(hass, config) - - # Test no exception if device is not loaded - await hass.config_entries.async_unload(entry.entry_id) - assert await device_trigger.async_validate_trigger_config(hass, config) == config diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 28c3f9c2803..bdcd5a37651 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -25,7 +25,7 @@ from homeassistant.exceptions import ( ) from homeassistant.helpers import entity_registry as er from homeassistant.helpers.update_coordinator import DataUpdateCoordinator -from homeassistant.setup import async_setup_component +from homeassistant.setup import async_set_domains_to_be_loaded, async_setup_component from homeassistant.util import dt from tests.common import ( @@ -3461,3 +3461,67 @@ async def test_get_active_flows(hass): iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_USER})), None ) assert active_user_flow is None + + +async def test_async_wait_component_dynamic(hass: HomeAssistant): + """Test async_wait_component for a config entry which is dynamically loaded.""" + entry = MockConfigEntry(title="test_title", domain="test") + + mock_setup_entry = AsyncMock(return_value=True) + mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry)) + mock_entity_platform(hass, "config_flow.test", None) + + entry.add_to_hass(hass) + + # The config entry is not loaded, and is also not scheduled to load + assert await hass.config_entries.async_wait_component(entry) is False + + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + # The config entry is loaded + assert await hass.config_entries.async_wait_component(entry) is True + + +async def test_async_wait_component_startup(hass: HomeAssistant): + """Test async_wait_component for a config entry which is loaded at startup.""" + entry = MockConfigEntry(title="test_title", domain="test") + + setup_stall = asyncio.Event() + setup_started = asyncio.Event() + + async def mock_setup(hass: HomeAssistant, _) -> bool: + setup_started.set() + await setup_stall.wait() + return True + + mock_setup_entry = AsyncMock(return_value=True) + mock_integration( + hass, + MockModule("test", async_setup=mock_setup, async_setup_entry=mock_setup_entry), + ) + mock_entity_platform(hass, "config_flow.test", None) + + entry.add_to_hass(hass) + + # The config entry is not loaded, and is also not scheduled to load + assert await hass.config_entries.async_wait_component(entry) is False + + # Mark the component as scheduled to be loaded + async_set_domains_to_be_loaded(hass, {"test"}) + + # Start loading the component, including its config entries + hass.async_create_task(async_setup_component(hass, "test", {})) + await setup_started.wait() + + # The component is not yet loaded + assert "test" not in hass.config.components + + # Allow setup to proceed + setup_stall.set() + + # The component is scheduled to load, this will block until the config entry is loaded + assert await hass.config_entries.async_wait_component(entry) is True + + # The component has been loaded + assert "test" in hass.config.components