diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 0d03c4f81eb..33ccd5615a8 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -661,6 +661,7 @@ class Integration: self._all_dependencies_resolved = True self._all_dependencies = set() + self._component_future: asyncio.Future[ComponentProtocol] | None = None self._import_futures: dict[str, asyncio.Future[ModuleType]] = {} _LOGGER.info("Loaded %s from %s", self.domain, pkg_path) @@ -842,34 +843,60 @@ class Integration: and will check if import_executor is set and load it in the executor, otherwise it will load it in the event loop. """ + if self._component_future: + return await self._component_future + if debug := _LOGGER.isEnabledFor(logging.DEBUG): start = time.perf_counter() - domain = self.domain + # Some integrations fail on import because they call functions incorrectly. # So we do it before validating config to catch these errors. load_executor = self.import_executor and ( self.pkg_path not in sys.modules or (self.config_flow and f"{self.pkg_path}.config_flow" not in sys.modules) ) - if load_executor: + if not load_executor: + comp = self.get_component() + if debug: + _LOGGER.debug( + "Component %s import took %.3f seconds (loaded_executor=False)", + self.domain, + time.perf_counter() - start, + ) + return comp + + self._component_future = self.hass.loop.create_future() + try: try: comp = await self.hass.async_add_import_executor_job(self.get_component) except ImportError as ex: load_executor = False - _LOGGER.debug("Failed to import %s in executor", domain, exc_info=ex) + _LOGGER.debug( + "Failed to import %s in executor", self.domain, exc_info=ex + ) # If importing in the executor deadlocks because there is a circular # dependency, we fall back to the event loop. comp = self.get_component() - else: - comp = self.get_component() + self._component_future.set_result(comp) + except BaseException as ex: + self._component_future.set_exception(ex) + with suppress(BaseException): + # Set the exception retrieved flag on the future since + # it will never be retrieved unless there + # are concurrent calls to async_get_component + self._component_future.result() + raise + finally: + self._component_future = None if debug: _LOGGER.debug( "Component %s import took %.3f seconds (loaded_executor=%s)", - domain, + self.domain, time.perf_counter() - start, load_executor, ) + return comp def get_component(self) -> ComponentProtocol: diff --git a/tests/test_loader.py b/tests/test_loader.py index a70bf7d4e3f..8400adca5c4 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -2,6 +2,7 @@ import asyncio import os import sys +import threading from typing import Any from unittest.mock import MagicMock, Mock, patch @@ -1086,8 +1087,8 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules( assert integration.import_executor is True assert integration.config_flow is True - assert "executor_import" not in hass.config.components - assert "executor_import.config_flow" not in hass.config.components + assert "test_package_loaded_executor" not in hass.config.components + assert "test_package_loaded_executor.config_flow" not in hass.config.components config_flow_module_name = f"{integration.pkg_path}.config_flow" module_mock = MagicMock() @@ -1133,6 +1134,64 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules( assert module is module_mock +async def test_async_get_component_concurrent_loads( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + enable_custom_integrations: None, +) -> None: + """Verify async_get_component waits if the first load if called again when still in progress.""" + integration = await loader.async_get_integration( + hass, "test_package_loaded_executor" + ) + assert integration.pkg_path == "custom_components.test_package_loaded_executor" + assert integration.import_executor is True + assert integration.config_flow is True + + assert "test_package_loaded_executor" not in hass.config.components + assert "test_package_loaded_executor.config_flow" not in hass.config.components + + config_flow_module_name = f"{integration.pkg_path}.config_flow" + module_mock = MagicMock() + config_flow_module_mock = MagicMock() + imports = [] + start_event = threading.Event() + import_event = asyncio.Event() + + def import_module(name: str) -> Any: + hass.loop.call_soon_threadsafe(import_event.set) + imports.append(name) + start_event.wait() + if name == integration.pkg_path: + return module_mock + if name == config_flow_module_name: + return config_flow_module_mock + raise ImportError + + modules_without_integration = { + k: v + for k, v in sys.modules.items() + if k != config_flow_module_name and k != integration.pkg_path + } + with patch.dict( + "sys.modules", + {**modules_without_integration}, + clear=True, + ), patch("homeassistant.loader.importlib.import_module", import_module): + load_task1 = asyncio.create_task(integration.async_get_component()) + load_task2 = asyncio.create_task(integration.async_get_component()) + await import_event.wait() # make sure the import is started + assert not integration._component_future.done() + start_event.set() + comp1 = await load_task1 + comp2 = await load_task2 + assert integration._component_future is None + + assert comp1 is module_mock + assert comp2 is module_mock + + assert imports == [integration.pkg_path, config_flow_module_name] + + async def test_async_get_component_deadlock_fallback( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: