From 3b9d6f2ddebbffb02563a7fbbba67c6e8a126168 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 15 Aug 2023 10:59:42 +0200 Subject: [PATCH] Add setup function to the component loader (#98148) * Add setup function to the component loader * Update test * Setup the loader in safe mode and in check_config script --- homeassistant/bootstrap.py | 3 ++ homeassistant/loader.py | 26 ++++++++-------- homeassistant/scripts/check_config.py | 3 +- tests/common.py | 22 ++++---------- .../components/device_automation/test_init.py | 30 +++++++++---------- .../components/websocket_api/test_commands.py | 2 +- 6 files changed, 38 insertions(+), 48 deletions(-) diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 6a667884962..196a00dda7c 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -134,6 +134,7 @@ async def async_setup_hass( _LOGGER.info("Config directory: %s", runtime_config.config_dir) + loader.async_setup(hass) config_dict = None basic_setup_success = False @@ -185,6 +186,8 @@ async def async_setup_hass( hass.config.internal_url = old_config.internal_url hass.config.external_url = old_config.external_url hass.config.config_dir = old_config.config_dir + # Setup loader cache after the config dir has been set + loader.async_setup(hass) if safe_mode: _LOGGER.info("Starting in safe mode") diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 6c083b6a024..340888a2f7a 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -166,6 +166,13 @@ class Manifest(TypedDict, total=False): loggers: list[str] +def async_setup(hass: HomeAssistant) -> None: + """Set up the necessary data structures.""" + _async_mount_config_dir(hass) + hass.data[DATA_COMPONENTS] = {} + hass.data[DATA_INTEGRATIONS] = {} + + def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest: """Generate a manifest from a legacy module.""" return { @@ -802,9 +809,7 @@ class Integration: def get_component(self) -> ComponentProtocol: """Return the component.""" - cache: dict[str, ComponentProtocol] = self.hass.data.setdefault( - DATA_COMPONENTS, {} - ) + cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS] if self.domain in cache: return cache[self.domain] @@ -824,7 +829,7 @@ class Integration: def get_platform(self, platform_name: str) -> ModuleType: """Return a platform for an integration.""" - cache: dict[str, ModuleType] = self.hass.data.setdefault(DATA_COMPONENTS, {}) + cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS] full_name = f"{self.domain}.{platform_name}" if full_name in cache: return cache[full_name] @@ -883,11 +888,7 @@ async def async_get_integrations( hass: HomeAssistant, domains: Iterable[str] ) -> dict[str, Integration | Exception]: """Get integrations.""" - if (cache := hass.data.get(DATA_INTEGRATIONS)) is None: - if not _async_mount_config_dir(hass): - return {domain: IntegrationNotFound(domain) for domain in domains} - cache = hass.data[DATA_INTEGRATIONS] = {} - + cache = hass.data[DATA_INTEGRATIONS] results: dict[str, Integration | Exception] = {} needed: dict[str, asyncio.Future[None]] = {} in_progress: dict[str, asyncio.Future[None]] = {} @@ -993,10 +994,7 @@ def _load_file( comp_or_platform ] - if (cache := hass.data.get(DATA_COMPONENTS)) is None: - if not _async_mount_config_dir(hass): - return None - cache = hass.data[DATA_COMPONENTS] = {} + cache = hass.data[DATA_COMPONENTS] for path in (f"{base}.{comp_or_platform}" for base in base_paths): try: @@ -1066,7 +1064,7 @@ class Components: def __getattr__(self, comp_name: str) -> ModuleWrapper: """Fetch a component.""" # Test integration cache - integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name) + integration = self._hass.data[DATA_INTEGRATIONS].get(comp_name) if isinstance(integration, Integration): component: ComponentProtocol | None = integration.get_component() diff --git a/homeassistant/scripts/check_config.py b/homeassistant/scripts/check_config.py index 5384b86cb98..7c4a200bbc5 100644 --- a/homeassistant/scripts/check_config.py +++ b/homeassistant/scripts/check_config.py @@ -11,7 +11,7 @@ import os from typing import Any from unittest.mock import patch -from homeassistant import core +from homeassistant import core, loader from homeassistant.config import get_default_config_dir from homeassistant.config_entries import ConfigEntries from homeassistant.exceptions import HomeAssistantError @@ -232,6 +232,7 @@ def check(config_dir, secrets=False): async def async_check_config(config_dir): """Check the HA config.""" hass = core.HomeAssistant() + loader.async_setup(hass) hass.config.config_dir = config_dir hass.config_entries = ConfigEntries(hass, {}) await ar.async_load(hass) diff --git a/tests/common.py b/tests/common.py index 0431743cccf..95947719ef4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -256,6 +256,7 @@ async def async_test_home_assistant(event_loop, load_registries=True): # Load the registries entity.async_setup(hass) + loader.async_setup(hass) if load_registries: with patch( "homeassistant.helpers.storage.Store.async_load", return_value=None @@ -1339,16 +1340,10 @@ def mock_integration( integration._import_platform = mock_import_platform _LOGGER.info("Adding mock integration: %s", module.DOMAIN) - integration_cache = hass.data.get(loader.DATA_INTEGRATIONS) - if integration_cache is None: - integration_cache = hass.data[loader.DATA_INTEGRATIONS] = {} - loader._async_mount_config_dir(hass) + integration_cache = hass.data[loader.DATA_INTEGRATIONS] integration_cache[module.DOMAIN] = integration - module_cache = hass.data.get(loader.DATA_COMPONENTS) - if module_cache is None: - module_cache = hass.data[loader.DATA_COMPONENTS] = {} - loader._async_mount_config_dir(hass) + module_cache = hass.data[loader.DATA_COMPONENTS] module_cache[module.DOMAIN] = module return integration @@ -1374,15 +1369,8 @@ def mock_platform( platform_path is in form hue.config_flow. """ domain = platform_path.split(".")[0] - integration_cache = hass.data.get(loader.DATA_INTEGRATIONS) - if integration_cache is None: - integration_cache = hass.data[loader.DATA_INTEGRATIONS] = {} - loader._async_mount_config_dir(hass) - - module_cache = hass.data.get(loader.DATA_COMPONENTS) - if module_cache is None: - module_cache = hass.data[loader.DATA_COMPONENTS] = {} - loader._async_mount_config_dir(hass) + integration_cache = hass.data[loader.DATA_INTEGRATIONS] + module_cache = hass.data[loader.DATA_COMPONENTS] if domain not in integration_cache: mock_integration(hass, MockModule(domain)) diff --git a/tests/components/device_automation/test_init.py b/tests/components/device_automation/test_init.py index 65fee1053ae..74150af67ae 100644 --- a/tests/components/device_automation/test_init.py +++ b/tests/components/device_automation/test_init.py @@ -304,7 +304,7 @@ async def test_websocket_get_action_capabilities( return {"extra_fields": vol.Schema({vol.Optional("code"): str})} return {} - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_action"] module.async_get_action_capabilities = _async_get_action_capabilities @@ -406,7 +406,7 @@ async def test_websocket_get_action_capabilities_bad_action( await async_setup_component(hass, "device_automation", {}) expected_capabilities = {} - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_action"] module.async_get_action_capabilities = Mock( side_effect=InvalidDeviceAutomationConfig @@ -459,7 +459,7 @@ async def test_websocket_get_condition_capabilities( """List condition capabilities.""" return await toggle_entity.async_get_condition_capabilities(hass, config) - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_condition"] module.async_get_condition_capabilities = _async_get_condition_capabilities @@ -569,7 +569,7 @@ async def test_websocket_get_condition_capabilities_bad_condition( await async_setup_component(hass, "device_automation", {}) expected_capabilities = {} - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_condition"] module.async_get_condition_capabilities = Mock( side_effect=InvalidDeviceAutomationConfig @@ -747,7 +747,7 @@ async def test_websocket_get_trigger_capabilities( """List trigger capabilities.""" return await toggle_entity.async_get_trigger_capabilities(hass, config) - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_trigger"] module.async_get_trigger_capabilities = _async_get_trigger_capabilities @@ -857,7 +857,7 @@ async def test_websocket_get_trigger_capabilities_bad_trigger( await async_setup_component(hass, "device_automation", {}) expected_capabilities = {} - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_trigger"] module.async_get_trigger_capabilities = Mock( side_effect=InvalidDeviceAutomationConfig @@ -912,7 +912,7 @@ async def test_automation_with_device_action( ) -> None: """Test automation with a device action.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_action"] module.async_call_action_from_config = AsyncMock() @@ -949,7 +949,7 @@ async def test_automation_with_dynamically_validated_action( ) -> None: """Test device automation with an action which is dynamically validated.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_action"] module.async_validate_action_config = AsyncMock() @@ -1003,7 +1003,7 @@ async def test_automation_with_device_condition( ) -> None: """Test automation with a device condition.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_condition"] module.async_condition_from_config = Mock() @@ -1037,7 +1037,7 @@ async def test_automation_with_dynamically_validated_condition( ) -> None: """Test device automation with a condition which is dynamically validated.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_condition"] module.async_validate_condition_config = AsyncMock() @@ -1102,7 +1102,7 @@ async def test_automation_with_device_trigger( ) -> None: """Test automation with a device trigger.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_trigger"] module.async_attach_trigger = AsyncMock() @@ -1136,7 +1136,7 @@ async def test_automation_with_dynamically_validated_trigger( ) -> None: """Test device automation with a trigger which is dynamically validated.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_trigger"] module.async_attach_trigger = AsyncMock() module.async_validate_trigger_config = AsyncMock(wraps=lambda hass, config: config) @@ -1457,7 +1457,7 @@ async def test_automation_with_unknown_device( ) -> None: """Test device automation with a trigger with an unknown device.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_trigger"] module.async_validate_trigger_config = AsyncMock() @@ -1492,7 +1492,7 @@ async def test_automation_with_device_wrong_domain( ) -> None: """Test device automation where the device doesn't have the right config entry.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_trigger"] module.async_validate_trigger_config = AsyncMock() @@ -1534,7 +1534,7 @@ async def test_automation_with_device_component_not_loaded( ) -> None: """Test device automation where the device's config entry is not loaded.""" - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_trigger"] module.async_validate_trigger_config = AsyncMock() module.async_attach_trigger = AsyncMock() diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 3a68bbd88d3..85c0ac62b25 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -1810,7 +1810,7 @@ async def test_execute_script_with_dynamically_validated_action( ws_client = await hass_ws_client(hass) - module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + module_cache = hass.data[loader.DATA_COMPONENTS] module = module_cache["fake_integration.device_action"] module.async_call_action_from_config = AsyncMock() module.async_validate_action_config = AsyncMock(