mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 10:47:10 +00:00
Avoid multiple executor jobs with concurrent calls to async_get_component (#112155)
This commit is contained in:
parent
99414d8b85
commit
943996b60b
@ -661,6 +661,7 @@ class Integration:
|
|||||||
self._all_dependencies_resolved = True
|
self._all_dependencies_resolved = True
|
||||||
self._all_dependencies = set()
|
self._all_dependencies = set()
|
||||||
|
|
||||||
|
self._component_future: asyncio.Future[ComponentProtocol] | None = None
|
||||||
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
|
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
|
||||||
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
|
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
|
||||||
|
|
||||||
@ -842,34 +843,60 @@ class Integration:
|
|||||||
and will check if import_executor is set and load it in the executor,
|
and will check if import_executor is set and load it in the executor,
|
||||||
otherwise it will load it in the event loop.
|
otherwise it will load it in the event loop.
|
||||||
"""
|
"""
|
||||||
|
if self._component_future:
|
||||||
|
return await self._component_future
|
||||||
|
|
||||||
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
domain = self.domain
|
|
||||||
# Some integrations fail on import because they call functions incorrectly.
|
# Some integrations fail on import because they call functions incorrectly.
|
||||||
# So we do it before validating config to catch these errors.
|
# So we do it before validating config to catch these errors.
|
||||||
load_executor = self.import_executor and (
|
load_executor = self.import_executor and (
|
||||||
self.pkg_path not in sys.modules
|
self.pkg_path not in sys.modules
|
||||||
or (self.config_flow and f"{self.pkg_path}.config_flow" not in sys.modules)
|
or (self.config_flow and f"{self.pkg_path}.config_flow" not in sys.modules)
|
||||||
)
|
)
|
||||||
if load_executor:
|
if not load_executor:
|
||||||
|
comp = self.get_component()
|
||||||
|
if debug:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Component %s import took %.3f seconds (loaded_executor=False)",
|
||||||
|
self.domain,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
|
return comp
|
||||||
|
|
||||||
|
self._component_future = self.hass.loop.create_future()
|
||||||
|
try:
|
||||||
try:
|
try:
|
||||||
comp = await self.hass.async_add_import_executor_job(self.get_component)
|
comp = await self.hass.async_add_import_executor_job(self.get_component)
|
||||||
except ImportError as ex:
|
except ImportError as ex:
|
||||||
load_executor = False
|
load_executor = False
|
||||||
_LOGGER.debug("Failed to import %s in executor", domain, exc_info=ex)
|
_LOGGER.debug(
|
||||||
|
"Failed to import %s in executor", self.domain, exc_info=ex
|
||||||
|
)
|
||||||
# If importing in the executor deadlocks because there is a circular
|
# If importing in the executor deadlocks because there is a circular
|
||||||
# dependency, we fall back to the event loop.
|
# dependency, we fall back to the event loop.
|
||||||
comp = self.get_component()
|
comp = self.get_component()
|
||||||
else:
|
self._component_future.set_result(comp)
|
||||||
comp = self.get_component()
|
except BaseException as ex:
|
||||||
|
self._component_future.set_exception(ex)
|
||||||
|
with suppress(BaseException):
|
||||||
|
# Set the exception retrieved flag on the future since
|
||||||
|
# it will never be retrieved unless there
|
||||||
|
# are concurrent calls to async_get_component
|
||||||
|
self._component_future.result()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._component_future = None
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Component %s import took %.3f seconds (loaded_executor=%s)",
|
"Component %s import took %.3f seconds (loaded_executor=%s)",
|
||||||
domain,
|
self.domain,
|
||||||
time.perf_counter() - start,
|
time.perf_counter() - start,
|
||||||
load_executor,
|
load_executor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return comp
|
return comp
|
||||||
|
|
||||||
def get_component(self) -> ComponentProtocol:
|
def get_component(self) -> ComponentProtocol:
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
@ -1086,8 +1087,8 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules(
|
|||||||
assert integration.import_executor is True
|
assert integration.import_executor is True
|
||||||
assert integration.config_flow is True
|
assert integration.config_flow is True
|
||||||
|
|
||||||
assert "executor_import" not in hass.config.components
|
assert "test_package_loaded_executor" not in hass.config.components
|
||||||
assert "executor_import.config_flow" not in hass.config.components
|
assert "test_package_loaded_executor.config_flow" not in hass.config.components
|
||||||
|
|
||||||
config_flow_module_name = f"{integration.pkg_path}.config_flow"
|
config_flow_module_name = f"{integration.pkg_path}.config_flow"
|
||||||
module_mock = MagicMock()
|
module_mock = MagicMock()
|
||||||
@ -1133,6 +1134,64 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules(
|
|||||||
assert module is module_mock
|
assert module is module_mock
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_component_concurrent_loads(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
enable_custom_integrations: None,
|
||||||
|
) -> None:
|
||||||
|
"""Verify async_get_component waits if the first load if called again when still in progress."""
|
||||||
|
integration = await loader.async_get_integration(
|
||||||
|
hass, "test_package_loaded_executor"
|
||||||
|
)
|
||||||
|
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
|
||||||
|
assert integration.import_executor is True
|
||||||
|
assert integration.config_flow is True
|
||||||
|
|
||||||
|
assert "test_package_loaded_executor" not in hass.config.components
|
||||||
|
assert "test_package_loaded_executor.config_flow" not in hass.config.components
|
||||||
|
|
||||||
|
config_flow_module_name = f"{integration.pkg_path}.config_flow"
|
||||||
|
module_mock = MagicMock()
|
||||||
|
config_flow_module_mock = MagicMock()
|
||||||
|
imports = []
|
||||||
|
start_event = threading.Event()
|
||||||
|
import_event = asyncio.Event()
|
||||||
|
|
||||||
|
def import_module(name: str) -> Any:
|
||||||
|
hass.loop.call_soon_threadsafe(import_event.set)
|
||||||
|
imports.append(name)
|
||||||
|
start_event.wait()
|
||||||
|
if name == integration.pkg_path:
|
||||||
|
return module_mock
|
||||||
|
if name == config_flow_module_name:
|
||||||
|
return config_flow_module_mock
|
||||||
|
raise ImportError
|
||||||
|
|
||||||
|
modules_without_integration = {
|
||||||
|
k: v
|
||||||
|
for k, v in sys.modules.items()
|
||||||
|
if k != config_flow_module_name and k != integration.pkg_path
|
||||||
|
}
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{**modules_without_integration},
|
||||||
|
clear=True,
|
||||||
|
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||||
|
load_task1 = asyncio.create_task(integration.async_get_component())
|
||||||
|
load_task2 = asyncio.create_task(integration.async_get_component())
|
||||||
|
await import_event.wait() # make sure the import is started
|
||||||
|
assert not integration._component_future.done()
|
||||||
|
start_event.set()
|
||||||
|
comp1 = await load_task1
|
||||||
|
comp2 = await load_task2
|
||||||
|
assert integration._component_future is None
|
||||||
|
|
||||||
|
assert comp1 is module_mock
|
||||||
|
assert comp2 is module_mock
|
||||||
|
|
||||||
|
assert imports == [integration.pkg_path, config_flow_module_name]
|
||||||
|
|
||||||
|
|
||||||
async def test_async_get_component_deadlock_fallback(
|
async def test_async_get_component_deadlock_fallback(
|
||||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user