From 170e6bdcab30e416fcaa990ba2e6a34e7d91a81c Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 9 Apr 2025 15:27:52 +0200 Subject: [PATCH] Protect hass data keys in setup.py (#142589) --- homeassistant/config_entries.py | 8 +-- homeassistant/setup.py | 62 ++++++++++-------- tests/test_setup.py | 107 ++++++++++++++++++++++---------- 3 files changed, 110 insertions(+), 67 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index ef1865da4be..b47815c9aa9 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -72,12 +72,12 @@ from .helpers.json import json_bytes, json_bytes_sorted, json_fragment from .helpers.typing import UNDEFINED, ConfigType, DiscoveryInfoType, UndefinedType from .loader import async_suggest_report_issue from .setup import ( - DATA_SETUP_DONE, SetupPhases, async_pause_setup, async_process_deps_reqs, async_setup_component, async_start_setup, + async_wait_component, ) from .util import ulid as ulid_util from .util.async_ import create_eager_task @@ -2701,11 +2701,7 @@ class ConfigEntries: Config entries which are created after Home Assistant is started can't be waited for, the function will just return if the config entry is loaded or not. """ - setup_done = self.hass.data.get(DATA_SETUP_DONE, {}) - if setup_future := setup_done.get(entry.domain): - await setup_future - # The component was not loaded. - if entry.domain not in self.hass.config.components: + if not await async_wait_component(self.hass, entry.domain): return False return entry.state is ConfigEntryState.LOADED diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 76061b72b73..39f0a7656f3 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -45,36 +45,36 @@ _LOGGER = logging.getLogger(__name__) ATTR_COMPONENT: Final = "component" -# DATA_SETUP is a dict, indicating domains which are currently +# _DATA_SETUP is a dict, indicating domains which are currently # being setup or which failed to setup: -# - Tasks are added to DATA_SETUP by `async_setup_component`, the key is the domain +# - Tasks are added to _DATA_SETUP by `async_setup_component`, the key is the domain # being setup and the Task is the `_async_setup_component` helper. -# - Tasks are removed from DATA_SETUP if setup was successful, that is, +# - Tasks are removed from _DATA_SETUP if setup was successful, that is, # the task returned True. -DATA_SETUP: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_tasks") +_DATA_SETUP: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_tasks") -# DATA_SETUP_DONE is a dict, indicating components which will be setup: -# - Events are added to DATA_SETUP_DONE during bootstrap by +# _DATA_SETUP_DONE is a dict, indicating components which will be setup: +# - Events are added to _DATA_SETUP_DONE during bootstrap by # async_set_domains_to_be_loaded, the key is the domain which will be loaded. -# - Events are set and removed from DATA_SETUP_DONE when async_setup_component +# - Events are set and removed from _DATA_SETUP_DONE when async_setup_component # is finished, regardless of if the setup was successful or not. -DATA_SETUP_DONE: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_done") +_DATA_SETUP_DONE: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_done") -# DATA_SETUP_STARTED is a dict, indicating when an attempt +# _DATA_SETUP_STARTED is a dict, indicating when an attempt # to setup a component started. -DATA_SETUP_STARTED: HassKey[dict[tuple[str, str | None], float]] = HassKey( +_DATA_SETUP_STARTED: HassKey[dict[tuple[str, str | None], float]] = HassKey( "setup_started" ) -# DATA_SETUP_TIME is a defaultdict, indicating how time was spent +# _DATA_SETUP_TIME is a defaultdict, indicating how time was spent # setting up a component. -DATA_SETUP_TIME: HassKey[ +_DATA_SETUP_TIME: HassKey[ defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]] ] = HassKey("setup_time") -DATA_DEPS_REQS: HassKey[set[str]] = HassKey("deps_reqs_processed") +_DATA_DEPS_REQS: HassKey[set[str]] = HassKey("deps_reqs_processed") -DATA_PERSISTENT_ERRORS: HassKey[dict[str, str | None]] = HassKey( +_DATA_PERSISTENT_ERRORS: HassKey[dict[str, str | None]] = HassKey( "bootstrap_persistent_errors" ) @@ -104,8 +104,8 @@ def async_notify_setup_error( # pylint: disable-next=import-outside-toplevel from .components import persistent_notification - if (errors := hass.data.get(DATA_PERSISTENT_ERRORS)) is None: - errors = hass.data[DATA_PERSISTENT_ERRORS] = {} + if (errors := hass.data.get(_DATA_PERSISTENT_ERRORS)) is None: + errors = hass.data[_DATA_PERSISTENT_ERRORS] = {} errors[component] = errors.get(component) or display_link @@ -131,8 +131,8 @@ def async_set_domains_to_be_loaded(hass: core.HomeAssistant, domains: set[str]) - Properly handle after_dependencies. - Keep track of domains which will load but have not yet finished loading """ - setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {}) - setup_futures = hass.data.setdefault(DATA_SETUP, {}) + setup_done_futures = hass.data.setdefault(_DATA_SETUP_DONE, {}) + setup_futures = hass.data.setdefault(_DATA_SETUP, {}) old_domains = set(setup_futures) | set(setup_done_futures) | hass.config.components if overlap := old_domains & domains: _LOGGER.debug("Domains to be loaded %s already loaded or pending", overlap) @@ -158,8 +158,8 @@ async def async_setup_component( if domain in hass.config.components: return True - setup_futures = hass.data.setdefault(DATA_SETUP, {}) - setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {}) + setup_futures = hass.data.setdefault(_DATA_SETUP, {}) + setup_done_futures = hass.data.setdefault(_DATA_SETUP_DONE, {}) if existing_setup_future := setup_futures.get(domain): return await existing_setup_future @@ -200,7 +200,7 @@ async def _async_process_dependencies( Returns a list of dependencies which failed to set up. """ - setup_futures = hass.data.setdefault(DATA_SETUP, {}) + setup_futures = hass.data.setdefault(_DATA_SETUP, {}) dependencies_tasks: dict[str, asyncio.Future[bool]] = {} @@ -216,7 +216,7 @@ async def _async_process_dependencies( ) dependencies_tasks[dep] = fut - to_be_loaded = hass.data.get(DATA_SETUP_DONE, {}) + to_be_loaded = hass.data.get(_DATA_SETUP_DONE, {}) # We don't want to just wait for the futures from `to_be_loaded` here. # We want to ensure that our after_dependencies are always actually # scheduled to be set up, as if for whatever reason they had not been, @@ -483,7 +483,7 @@ async def _async_setup_component( ) # Cleanup - hass.data[DATA_SETUP].pop(domain, None) + hass.data[_DATA_SETUP].pop(domain, None) hass.bus.async_fire_internal( EVENT_COMPONENT_LOADED, EventComponentLoaded(component=domain) @@ -573,8 +573,8 @@ async def async_process_deps_reqs( Module is a Python module of either a component or platform. """ - if (processed := hass.data.get(DATA_DEPS_REQS)) is None: - processed = hass.data[DATA_DEPS_REQS] = set() + if (processed := hass.data.get(_DATA_DEPS_REQS)) is None: + processed = hass.data[_DATA_DEPS_REQS] = set() elif integration.domain in processed: return @@ -689,7 +689,7 @@ class SetupPhases(StrEnum): """Wait time for the packages to import.""" -@singleton.singleton(DATA_SETUP_STARTED) +@singleton.singleton(_DATA_SETUP_STARTED) def _setup_started( hass: core.HomeAssistant, ) -> dict[tuple[str, str | None], float]: @@ -732,7 +732,7 @@ def async_pause_setup(hass: core.HomeAssistant, phase: SetupPhases) -> Generator ) -@singleton.singleton(DATA_SETUP_TIME) +@singleton.singleton(_DATA_SETUP_TIME) def _setup_times( hass: core.HomeAssistant, ) -> defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]]: @@ -832,3 +832,11 @@ def async_get_domain_setup_times( ) -> Mapping[str | None, dict[SetupPhases, float]]: """Return timing data for each integration.""" return _setup_times(hass).get(domain, {}) + + +async def async_wait_component(hass: HomeAssistant, domain: str) -> bool: + """Wait until a component is set up if pending, then return if it is set up.""" + setup_done = hass.data.get(_DATA_SETUP_DONE, {}) + if setup_future := setup_done.get(domain): + await setup_future + return domain in hass.config.components diff --git a/tests/test_setup.py b/tests/test_setup.py index 084b657a2f2..96a13017430 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -57,21 +57,21 @@ async def test_validate_component_config(hass: HomeAssistant) -> None: with assert_setup_component(0): assert not await setup.async_setup_component(hass, "comp_conf", {}) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) with assert_setup_component(0): assert not await setup.async_setup_component( hass, "comp_conf", {"comp_conf": None} ) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) with assert_setup_component(0): assert not await setup.async_setup_component( hass, "comp_conf", {"comp_conf": {}} ) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) with assert_setup_component(0): assert not await setup.async_setup_component( @@ -80,7 +80,7 @@ async def test_validate_component_config(hass: HomeAssistant) -> None: {"comp_conf": {"hello": "world", "invalid": "extra"}}, ) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) with assert_setup_component(1): assert await setup.async_setup_component( @@ -111,7 +111,7 @@ async def test_validate_platform_config( {"platform_conf": {"platform": "not_existing", "hello": "world"}}, ) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("platform_conf") with assert_setup_component(1): @@ -121,7 +121,7 @@ async def test_validate_platform_config( {"platform_conf": {"platform": "whatever", "hello": "world"}}, ) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("platform_conf") with assert_setup_component(1): @@ -131,7 +131,7 @@ async def test_validate_platform_config( {"platform_conf": [{"platform": "whatever", "hello": "world"}]}, ) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("platform_conf") # Any falsey platform config will be ignored (None, {}, etc) @@ -240,7 +240,7 @@ async def test_validate_platform_config_4(hass: HomeAssistant) -> None: }, ) - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("platform_conf") @@ -345,7 +345,7 @@ async def test_component_not_setup_missing_dependencies(hass: HomeAssistant) -> assert not await setup.async_setup_component(hass, "comp", {}) assert "comp" not in hass.config.components - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) mock_integration(hass, MockModule("comp2", dependencies=deps)) mock_integration(hass, MockModule("maybe_existing")) @@ -443,8 +443,8 @@ async def test_component_exception_setup(hass: HomeAssistant) -> None: mock_integration(hass, MockModule(domain, setup=exception_setup)) assert not await setup.async_setup_component(hass, domain, {}) - assert domain in hass.data[setup.DATA_SETUP] - assert domain not in hass.data[setup.DATA_SETUP_DONE] + assert domain in hass.data[setup._DATA_SETUP] + assert domain not in hass.data[setup._DATA_SETUP_DONE] assert domain not in hass.config.components @@ -463,8 +463,8 @@ async def test_component_base_exception_setup(hass: HomeAssistant) -> None: await setup.async_setup_component(hass, "comp", {}) assert str(exc_info.value) == "fail!" - assert domain in hass.data[setup.DATA_SETUP] - assert domain not in hass.data[setup.DATA_SETUP_DONE] + assert domain in hass.data[setup._DATA_SETUP] + assert domain not in hass.data[setup._DATA_SETUP_DONE] assert domain not in hass.config.components @@ -477,12 +477,12 @@ async def test_set_domains_to_be_loaded(hass: HomeAssistant) -> None: domains = {domain_good, domain_bad, domain_exception, domain_base_exception} setup.async_set_domains_to_be_loaded(hass, domains) - assert set(hass.data[setup.DATA_SETUP_DONE]) == domains - setup_done = dict(hass.data[setup.DATA_SETUP_DONE]) + assert set(hass.data[setup._DATA_SETUP_DONE]) == domains + setup_done = dict(hass.data[setup._DATA_SETUP_DONE]) # Calling async_set_domains_to_be_loaded again should not create new futures setup.async_set_domains_to_be_loaded(hass, domains) - assert setup_done == hass.data[setup.DATA_SETUP_DONE] + assert setup_done == hass.data[setup._DATA_SETUP_DONE] def good_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Success.""" @@ -515,8 +515,8 @@ async def test_set_domains_to_be_loaded(hass: HomeAssistant) -> None: await setup.async_setup_component(hass, domain_base_exception, {}) # Check the result of the setup - assert not hass.data[setup.DATA_SETUP_DONE] - assert set(hass.data[setup.DATA_SETUP]) == { + assert not hass.data[setup._DATA_SETUP_DONE] + assert set(hass.data[setup._DATA_SETUP]) == { domain_bad, domain_exception, domain_base_exception, @@ -525,7 +525,7 @@ async def test_set_domains_to_be_loaded(hass: HomeAssistant) -> None: # Calling async_set_domains_to_be_loaded again should not create any new futures setup.async_set_domains_to_be_loaded(hass, domains) - assert not hass.data[setup.DATA_SETUP_DONE] + assert not hass.data[setup._DATA_SETUP_DONE] async def test_component_setup_after_dependencies(hass: HomeAssistant) -> None: @@ -608,7 +608,7 @@ async def test_platform_specific_config_validation(hass: HomeAssistant) -> None: assert mock_setup.call_count == 0 assert len(mock_notify.mock_calls) == 1 - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("switch") with ( @@ -630,7 +630,7 @@ async def test_platform_specific_config_validation(hass: HomeAssistant) -> None: assert mock_setup.call_count == 0 assert len(mock_notify.mock_calls) == 1 - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("switch") with ( @@ -656,7 +656,7 @@ async def test_disable_component_if_invalid_return(hass: HomeAssistant) -> None: assert not await setup.async_setup_component(hass, "disabled_component", {}) assert "disabled_component" not in hass.config.components - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) mock_integration( hass, MockModule("disabled_component", setup=lambda hass, config: False), @@ -665,7 +665,7 @@ async def test_disable_component_if_invalid_return(hass: HomeAssistant) -> None: assert not await setup.async_setup_component(hass, "disabled_component", {}) assert "disabled_component" not in hass.config.components - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) mock_integration( hass, MockModule("disabled_component", setup=lambda hass, config: True) ) @@ -939,7 +939,7 @@ async def test_integration_only_setup_entry(hass: HomeAssistant) -> None: async def test_async_start_setup_running(hass: HomeAssistant) -> None: """Test setup started context manager does nothing when running.""" assert hass.state is CoreState.running - setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) + setup_started = hass.data.setdefault(setup._DATA_SETUP_STARTED, {}) with setup.async_start_setup( hass, integration="august", phase=setup.SetupPhases.SETUP @@ -952,7 +952,7 @@ async def test_async_start_setup_config_entry( ) -> None: """Test setup started keeps track of setup times with a config entry.""" hass.set_state(CoreState.not_running) - setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) + setup_started = hass.data.setdefault(setup._DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) with setup.async_start_setup( @@ -1062,7 +1062,7 @@ async def test_async_start_setup_config_entry_late_platform( ) -> None: """Test setup started tracks config entry time with a late platform load.""" hass.set_state(CoreState.not_running) - setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) + setup_started = hass.data.setdefault(setup._DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) with setup.async_start_setup( @@ -1116,7 +1116,7 @@ async def test_async_start_setup_config_entry_platform_wait( ) -> None: """Test setup started tracks wait time when a platform loads inside of config entry setup.""" hass.set_state(CoreState.not_running) - setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) + setup_started = hass.data.setdefault(setup._DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) with setup.async_start_setup( @@ -1158,7 +1158,7 @@ async def test_async_start_setup_config_entry_platform_wait( async def test_async_start_setup_top_level_yaml(hass: HomeAssistant) -> None: """Test setup started context manager keeps track of setup times with modern yaml.""" hass.set_state(CoreState.not_running) - setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) + setup_started = hass.data.setdefault(setup._DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) with setup.async_start_setup( @@ -1174,7 +1174,7 @@ async def test_async_start_setup_top_level_yaml(hass: HomeAssistant) -> None: async def test_async_start_setup_platform_integration(hass: HomeAssistant) -> None: """Test setup started keeps track of setup times a platform integration.""" hass.set_state(CoreState.not_running) - setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) + setup_started = hass.data.setdefault(setup._DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) with setup.async_start_setup( @@ -1208,7 +1208,7 @@ async def test_async_start_setup_legacy_platform_integration( ) -> None: """Test setup started keeps track of setup times for a legacy platform integration.""" hass.set_state(CoreState.not_running) - setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) + setup_started = hass.data.setdefault(setup._DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) with setup.async_start_setup( @@ -1330,7 +1330,7 @@ async def test_setup_config_entry_from_yaml( assert await setup.async_setup_component(hass, "test_integration_only_entry", {}) assert expected_warning not in caplog.text caplog.clear() - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("test_integration_only_entry") # There should be a warning, but setup should not fail @@ -1339,7 +1339,7 @@ async def test_setup_config_entry_from_yaml( ) assert expected_warning in caplog.text caplog.clear() - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("test_integration_only_entry") # There should be a warning, but setup should not fail @@ -1348,7 +1348,7 @@ async def test_setup_config_entry_from_yaml( ) assert expected_warning in caplog.text caplog.clear() - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("test_integration_only_entry") # There should be a warning, but setup should not fail @@ -1359,7 +1359,7 @@ async def test_setup_config_entry_from_yaml( ) assert expected_warning in caplog.text caplog.clear() - hass.data.pop(setup.DATA_SETUP) + hass.data.pop(setup._DATA_SETUP) hass.config.components.remove("test_integration_only_entry") @@ -1408,3 +1408,42 @@ async def test_async_prepare_setup_platform( await setup.async_prepare_setup_platform(hass, {}, "button", "test") is None ) assert button_platform is not None + + +async def test_async_wait_component(hass: HomeAssistant) -> None: + """Test async_wait_component.""" + setup_stall = asyncio.Event() + setup_started = asyncio.Event() + + async def mock_setup(hass: HomeAssistant, _) -> bool: + setup_started.set() + await setup_stall.wait() + return True + + mock_integration(hass, MockModule("test", async_setup=mock_setup)) + + # The integration not loaded, and is also not scheduled to load + assert await setup.async_wait_component(hass, "test") is False + + # Mark the component as scheduled to be loaded + setup.async_set_domains_to_be_loaded(hass, {"test"}) + + # Start loading the component, including its config entries + hass.async_create_task(setup.async_setup_component(hass, "test", {})) + await setup_started.wait() + + # The component is not yet loaded + assert "test" not in hass.config.components + + # Allow setup to proceed + setup_stall.set() + + # The component is scheduled to load, this will block until the config entry is loaded + assert await setup.async_wait_component(hass, "test") is True + + # The component has been loaded + assert "test" in hass.config.components + + # Clear the event, then call again to make sure we don't block + setup_stall.clear() + assert await setup.async_wait_component(hass, "test") is True