Avoid multiple executor jobs with concurrent calls to async_get_component (#112155)

This commit is contained in:
J. Nick Koston 2024-03-03 20:22:31 -10:00 committed by GitHub
parent 99414d8b85
commit 943996b60b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 8 deletions

View File

@ -661,6 +661,7 @@ class Integration:
self._all_dependencies_resolved = True self._all_dependencies_resolved = True
self._all_dependencies = set() self._all_dependencies = set()
self._component_future: asyncio.Future[ComponentProtocol] | None = None
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {} self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path) _LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
@ -842,34 +843,60 @@ class Integration:
and will check if import_executor is set and load it in the executor, and will check if import_executor is set and load it in the executor,
otherwise it will load it in the event loop. otherwise it will load it in the event loop.
""" """
if self._component_future:
return await self._component_future
if debug := _LOGGER.isEnabledFor(logging.DEBUG): if debug := _LOGGER.isEnabledFor(logging.DEBUG):
start = time.perf_counter() start = time.perf_counter()
domain = self.domain
# Some integrations fail on import because they call functions incorrectly. # Some integrations fail on import because they call functions incorrectly.
# So we do it before validating config to catch these errors. # So we do it before validating config to catch these errors.
load_executor = self.import_executor and ( load_executor = self.import_executor and (
self.pkg_path not in sys.modules self.pkg_path not in sys.modules
or (self.config_flow and f"{self.pkg_path}.config_flow" not in sys.modules) or (self.config_flow and f"{self.pkg_path}.config_flow" not in sys.modules)
) )
if load_executor: if not load_executor:
comp = self.get_component()
if debug:
_LOGGER.debug(
"Component %s import took %.3f seconds (loaded_executor=False)",
self.domain,
time.perf_counter() - start,
)
return comp
self._component_future = self.hass.loop.create_future()
try:
try: try:
comp = await self.hass.async_add_import_executor_job(self.get_component) comp = await self.hass.async_add_import_executor_job(self.get_component)
except ImportError as ex: except ImportError as ex:
load_executor = False load_executor = False
_LOGGER.debug("Failed to import %s in executor", domain, exc_info=ex) _LOGGER.debug(
"Failed to import %s in executor", self.domain, exc_info=ex
)
# If importing in the executor deadlocks because there is a circular # If importing in the executor deadlocks because there is a circular
# dependency, we fall back to the event loop. # dependency, we fall back to the event loop.
comp = self.get_component() comp = self.get_component()
else: self._component_future.set_result(comp)
comp = self.get_component() except BaseException as ex:
self._component_future.set_exception(ex)
with suppress(BaseException):
# Set the exception retrieved flag on the future since
# it will never be retrieved unless there
# are concurrent calls to async_get_component
self._component_future.result()
raise
finally:
self._component_future = None
if debug: if debug:
_LOGGER.debug( _LOGGER.debug(
"Component %s import took %.3f seconds (loaded_executor=%s)", "Component %s import took %.3f seconds (loaded_executor=%s)",
domain, self.domain,
time.perf_counter() - start, time.perf_counter() - start,
load_executor, load_executor,
) )
return comp return comp
def get_component(self) -> ComponentProtocol: def get_component(self) -> ComponentProtocol:

View File

@ -2,6 +2,7 @@
import asyncio import asyncio
import os import os
import sys import sys
import threading
from typing import Any from typing import Any
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
@ -1086,8 +1087,8 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules(
assert integration.import_executor is True assert integration.import_executor is True
assert integration.config_flow is True assert integration.config_flow is True
assert "executor_import" not in hass.config.components assert "test_package_loaded_executor" not in hass.config.components
assert "executor_import.config_flow" not in hass.config.components assert "test_package_loaded_executor.config_flow" not in hass.config.components
config_flow_module_name = f"{integration.pkg_path}.config_flow" config_flow_module_name = f"{integration.pkg_path}.config_flow"
module_mock = MagicMock() module_mock = MagicMock()
@ -1133,6 +1134,64 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules(
assert module is module_mock assert module is module_mock
async def test_async_get_component_concurrent_loads(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
enable_custom_integrations: None,
) -> None:
"""Verify async_get_component waits if the first load if called again when still in progress."""
integration = await loader.async_get_integration(
hass, "test_package_loaded_executor"
)
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
assert integration.import_executor is True
assert integration.config_flow is True
assert "test_package_loaded_executor" not in hass.config.components
assert "test_package_loaded_executor.config_flow" not in hass.config.components
config_flow_module_name = f"{integration.pkg_path}.config_flow"
module_mock = MagicMock()
config_flow_module_mock = MagicMock()
imports = []
start_event = threading.Event()
import_event = asyncio.Event()
def import_module(name: str) -> Any:
hass.loop.call_soon_threadsafe(import_event.set)
imports.append(name)
start_event.wait()
if name == integration.pkg_path:
return module_mock
if name == config_flow_module_name:
return config_flow_module_mock
raise ImportError
modules_without_integration = {
k: v
for k, v in sys.modules.items()
if k != config_flow_module_name and k != integration.pkg_path
}
with patch.dict(
"sys.modules",
{**modules_without_integration},
clear=True,
), patch("homeassistant.loader.importlib.import_module", import_module):
load_task1 = asyncio.create_task(integration.async_get_component())
load_task2 = asyncio.create_task(integration.async_get_component())
await import_event.wait() # make sure the import is started
assert not integration._component_future.done()
start_event.set()
comp1 = await load_task1
comp2 = await load_task2
assert integration._component_future is None
assert comp1 is module_mock
assert comp2 is module_mock
assert imports == [integration.pkg_path, config_flow_module_name]
async def test_async_get_component_deadlock_fallback( async def test_async_get_component_deadlock_fallback(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None: