mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Group loading of platforms in the import executor (#112141)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
e0a8a9d551
commit
5227976aa2
@ -506,7 +506,7 @@ class ConfigEntry:
|
||||
|
||||
if domain_is_integration:
|
||||
try:
|
||||
await integration.async_get_platform("config_flow")
|
||||
await integration.async_get_platforms(("config_flow",))
|
||||
except ImportError as err:
|
||||
_LOGGER.error(
|
||||
(
|
||||
@ -1814,6 +1814,8 @@ class ConfigEntries:
|
||||
self, entry: ConfigEntry, platforms: Iterable[Platform | str]
|
||||
) -> None:
|
||||
"""Forward the setup of an entry to platforms."""
|
||||
integration = await loader.async_get_integration(self.hass, entry.domain)
|
||||
await integration.async_get_platforms(platforms)
|
||||
await asyncio.gather(
|
||||
*(
|
||||
create_eager_task(
|
||||
@ -2519,7 +2521,7 @@ async def _load_integration(
|
||||
# Make sure requirements and dependencies of component are resolved
|
||||
await async_process_deps_reqs(hass, hass_config, integration)
|
||||
try:
|
||||
await integration.async_get_platform("config_flow")
|
||||
await integration.async_get_platforms(("config_flow",))
|
||||
except ImportError as err:
|
||||
_LOGGER.error(
|
||||
"Error occurred loading flow for integration %s: %s",
|
||||
|
@ -27,6 +27,7 @@ from awesomeversion import (
|
||||
import voluptuous as vol
|
||||
|
||||
from . import generated
|
||||
from .const import Platform
|
||||
from .core import HomeAssistant, callback
|
||||
from .generated.application_credentials import APPLICATION_CREDENTIALS
|
||||
from .generated.bluetooth import BLUETOOTH
|
||||
@ -663,6 +664,12 @@ class Integration:
|
||||
|
||||
self._component_future: asyncio.Future[ComponentProtocol] | None = None
|
||||
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
|
||||
cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS]
|
||||
self._cache = cache
|
||||
missing_platforms_cache: dict[str, ImportError] = hass.data[
|
||||
DATA_MISSING_PLATFORMS
|
||||
]
|
||||
self._missing_platforms_cache = missing_platforms_cache
|
||||
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
|
||||
|
||||
@cached_property
|
||||
@ -909,12 +916,14 @@ class Integration:
|
||||
with a dict cache which is thread-safe since importlib has
|
||||
appropriate locks.
|
||||
"""
|
||||
cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS]
|
||||
if self.domain in cache:
|
||||
return cache[self.domain]
|
||||
cache = self._cache
|
||||
domain = self.domain
|
||||
|
||||
if domain in cache:
|
||||
return cache[domain]
|
||||
|
||||
try:
|
||||
cache[self.domain] = cast(
|
||||
cache[domain] = cast(
|
||||
ComponentProtocol, importlib.import_module(self.pkg_path)
|
||||
)
|
||||
except ImportError:
|
||||
@ -945,75 +954,122 @@ class Integration:
|
||||
with suppress(ImportError):
|
||||
self.get_platform("config_flow")
|
||||
|
||||
return cache[self.domain]
|
||||
return cache[domain]
|
||||
|
||||
def _load_platforms(self, platform_names: Iterable[str]) -> dict[str, ModuleType]:
|
||||
"""Load platforms for an integration."""
|
||||
return {
|
||||
platform_name: self._load_platform(platform_name)
|
||||
for platform_name in platform_names
|
||||
}
|
||||
|
||||
async def async_get_platform(self, platform_name: str) -> ModuleType:
|
||||
"""Return a platform for an integration."""
|
||||
platforms = await self.async_get_platforms([platform_name])
|
||||
return platforms[platform_name]
|
||||
|
||||
async def async_get_platforms(
|
||||
self, platform_names: Iterable[Platform | str]
|
||||
) -> dict[str, ModuleType]:
|
||||
"""Return a platforms for an integration."""
|
||||
domain = self.domain
|
||||
full_name = f"{self.domain}.{platform_name}"
|
||||
if platform := self._get_platform_cached(full_name):
|
||||
return platform
|
||||
if future := self._import_futures.get(full_name):
|
||||
return await future
|
||||
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
start = time.perf_counter()
|
||||
import_future = self.hass.loop.create_future()
|
||||
self._import_futures[full_name] = import_future
|
||||
load_executor = (
|
||||
self.import_executor
|
||||
and domain not in self.hass.config.components
|
||||
and f"{self.pkg_path}.{domain}" not in sys.modules
|
||||
)
|
||||
try:
|
||||
if load_executor:
|
||||
try:
|
||||
platform = await self.hass.async_add_import_executor_job(
|
||||
self._load_platform, platform_name
|
||||
)
|
||||
except ImportError as ex:
|
||||
_LOGGER.debug(
|
||||
"Failed to import %s in executor", domain, exc_info=ex
|
||||
)
|
||||
load_executor = False
|
||||
# If importing in the executor deadlocks because there is a circular
|
||||
# dependency, we fall back to the event loop.
|
||||
platform = self._load_platform(platform_name)
|
||||
platforms: dict[str, ModuleType] = {}
|
||||
|
||||
load_executor_platforms: list[str] = []
|
||||
load_event_loop_platforms: list[str] = []
|
||||
in_progress_imports: dict[str, asyncio.Future[ModuleType]] = {}
|
||||
import_futures: list[tuple[str, asyncio.Future[ModuleType]]] = []
|
||||
|
||||
for platform_name in platform_names:
|
||||
full_name = f"{domain}.{platform_name}"
|
||||
if platform := self._get_platform_cached(full_name):
|
||||
platforms[platform_name] = platform
|
||||
continue
|
||||
|
||||
# Another call to async_get_platforms is already importing this platform
|
||||
if future := self._import_futures.get(platform_name):
|
||||
in_progress_imports[platform_name] = future
|
||||
continue
|
||||
|
||||
if (
|
||||
self.import_executor
|
||||
and full_name not in self.hass.config.components
|
||||
and f"{self.pkg_path}.{platform_name}" not in sys.modules
|
||||
):
|
||||
load_executor_platforms.append(platform_name)
|
||||
else:
|
||||
platform = self._load_platform(platform_name)
|
||||
import_future.set_result(platform)
|
||||
except BaseException as ex:
|
||||
import_future.set_exception(ex)
|
||||
with suppress(BaseException):
|
||||
# Clear the exception retrieved flag on the future since
|
||||
# it will never be retrieved unless there
|
||||
# are concurrent calls to async_get_platform
|
||||
import_future.result()
|
||||
raise
|
||||
finally:
|
||||
self._import_futures.pop(full_name)
|
||||
load_event_loop_platforms.append(platform_name)
|
||||
|
||||
if debug:
|
||||
_LOGGER.debug(
|
||||
"Importing platform %s took %.2fs (loaded_executor=%s)",
|
||||
full_name,
|
||||
time.perf_counter() - start,
|
||||
load_executor,
|
||||
)
|
||||
import_future = self.hass.loop.create_future()
|
||||
self._import_futures[platform_name] = import_future
|
||||
import_futures.append((platform_name, import_future))
|
||||
|
||||
return platform
|
||||
if load_executor_platforms or load_event_loop_platforms:
|
||||
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
start = time.perf_counter()
|
||||
|
||||
try:
|
||||
if load_executor_platforms:
|
||||
try:
|
||||
platforms.update(
|
||||
await self.hass.async_add_import_executor_job(
|
||||
self._load_platforms, platform_names
|
||||
)
|
||||
)
|
||||
except ImportError as ex:
|
||||
_LOGGER.debug(
|
||||
"Failed to import %s platforms %s in executor",
|
||||
domain,
|
||||
load_executor_platforms,
|
||||
exc_info=ex,
|
||||
)
|
||||
# If importing in the executor deadlocks because there is a circular
|
||||
# dependency, we fall back to the event loop.
|
||||
load_event_loop_platforms.extend(load_executor_platforms)
|
||||
|
||||
if load_event_loop_platforms:
|
||||
platforms.update(self._load_platforms(platform_names))
|
||||
|
||||
for platform_name, import_future in import_futures:
|
||||
import_future.set_result(platforms[platform_name])
|
||||
|
||||
except BaseException as ex:
|
||||
for _, import_future in import_futures:
|
||||
import_future.set_exception(ex)
|
||||
with suppress(BaseException):
|
||||
# Set the exception retrieved flag on the future since
|
||||
# it will never be retrieved unless there
|
||||
# are concurrent calls to async_get_platforms
|
||||
import_future.result()
|
||||
raise
|
||||
|
||||
finally:
|
||||
for platform_name, _ in import_futures:
|
||||
self._import_futures.pop(platform_name)
|
||||
|
||||
if debug:
|
||||
_LOGGER.debug(
|
||||
"Importing platforms for %s executor=%s loop=%s took %.2fs",
|
||||
domain,
|
||||
load_executor_platforms,
|
||||
load_event_loop_platforms,
|
||||
time.perf_counter() - start,
|
||||
)
|
||||
|
||||
if in_progress_imports:
|
||||
for platform_name, future in in_progress_imports.items():
|
||||
platforms[platform_name] = await future
|
||||
|
||||
return platforms
|
||||
|
||||
def _get_platform_cached(self, full_name: str) -> ModuleType | None:
|
||||
"""Return a platform for an integration from cache."""
|
||||
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
|
||||
if full_name in cache:
|
||||
return cache[full_name]
|
||||
|
||||
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
|
||||
DATA_MISSING_PLATFORMS
|
||||
]
|
||||
if full_name in missing_platforms_cache:
|
||||
raise missing_platforms_cache[full_name]
|
||||
|
||||
if full_name in self._cache:
|
||||
# the cache is either a ModuleType or a ComponentProtocol
|
||||
# but we only care about the ModuleType here
|
||||
return self._cache[full_name] # type: ignore[return-value]
|
||||
if full_name in self._missing_platforms_cache:
|
||||
raise self._missing_platforms_cache[full_name]
|
||||
return None
|
||||
|
||||
def get_platform(self, platform_name: str) -> ModuleType:
|
||||
@ -1033,14 +1089,11 @@ class Integration:
|
||||
has been called for the integration or the integration failed to load.
|
||||
"""
|
||||
full_name = f"{self.domain}.{platform_name}"
|
||||
|
||||
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
|
||||
cache = self._cache
|
||||
if full_name in cache:
|
||||
return True
|
||||
|
||||
missing_platforms_cache: dict[str, ImportError]
|
||||
missing_platforms_cache = self.hass.data[DATA_MISSING_PLATFORMS]
|
||||
if full_name in missing_platforms_cache:
|
||||
if full_name in self._missing_platforms_cache:
|
||||
return False
|
||||
|
||||
if not (component := cache.get(self.domain)) or not (
|
||||
@ -1056,7 +1109,7 @@ class Integration:
|
||||
f"Platform {full_name} not found",
|
||||
name=f"{self.pkg_path}.{platform_name}",
|
||||
)
|
||||
missing_platforms_cache[full_name] = exc
|
||||
self._missing_platforms_cache[full_name] = exc
|
||||
return False
|
||||
|
||||
def _load_platform(self, platform_name: str) -> ModuleType:
|
||||
@ -1077,10 +1130,7 @@ class Integration:
|
||||
if self.domain in cache:
|
||||
# If the domain is loaded, cache that the platform
|
||||
# does not exist so we do not try to load it again
|
||||
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
|
||||
DATA_MISSING_PLATFORMS
|
||||
]
|
||||
missing_platforms_cache[full_name] = ex
|
||||
self._missing_platforms_cache[full_name] = ex
|
||||
raise
|
||||
except RuntimeError as err:
|
||||
# _DeadlockError inherits from RuntimeError
|
||||
|
@ -278,6 +278,41 @@ async def test_async_get_platform_caches_failures_when_component_loaded(
|
||||
assert await integration.async_get_platform("light") == hue_light
|
||||
|
||||
|
||||
async def test_async_get_platforms_caches_failures_when_component_loaded(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test async_get_platforms cache failures only when the component is loaded."""
|
||||
integration = await loader.async_get_integration(hass, "hue")
|
||||
|
||||
with pytest.raises(ImportError), patch(
|
||||
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||
):
|
||||
assert integration.get_component() == hue
|
||||
|
||||
with pytest.raises(ImportError), patch(
|
||||
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||
):
|
||||
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||
|
||||
# Hue is not loaded so we should still hit the import_module path
|
||||
with pytest.raises(ImportError), patch(
|
||||
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||
):
|
||||
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||
|
||||
assert integration.get_component() == hue
|
||||
|
||||
# Hue is loaded so we should cache the import_module failure now
|
||||
with pytest.raises(ImportError), patch(
|
||||
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||
):
|
||||
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||
|
||||
# Hue is loaded and the last call should have cached the import_module failure
|
||||
with pytest.raises(ImportError):
|
||||
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||
|
||||
|
||||
async def test_get_integration_legacy(
|
||||
hass: HomeAssistant, enable_custom_integrations: None
|
||||
) -> None:
|
||||
@ -1302,7 +1337,9 @@ async def test_async_get_platform_deadlock_fallback(
|
||||
"Detected deadlock trying to import homeassistant.components.executor_import"
|
||||
in caplog.text
|
||||
)
|
||||
assert "loaded_executor=False" in caplog.text
|
||||
# We should have tried both the executor and loop
|
||||
assert "executor=['config_flow']" in caplog.text
|
||||
assert "loop=['config_flow']" in caplog.text
|
||||
assert module is module_mock
|
||||
|
||||
|
||||
@ -1390,3 +1427,155 @@ async def test_platform_exists(
|
||||
assert platform.MAGIC == 1
|
||||
|
||||
assert integration.platform_exists("group") is True
|
||||
|
||||
|
||||
async def test_async_get_platforms_loads_loop_if_already_in_sys_modules(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
enable_custom_integrations: None,
|
||||
) -> None:
|
||||
"""Verify async_get_platforms does not create an executor job.
|
||||
|
||||
Case is for when the module is already in sys.modules.
|
||||
"""
|
||||
integration = await loader.async_get_integration(
|
||||
hass, "test_package_loaded_executor"
|
||||
)
|
||||
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
|
||||
assert integration.import_executor is True
|
||||
assert integration.config_flow is True
|
||||
|
||||
assert "test_package_loaded_executor" not in hass.config.components
|
||||
assert "test_package_loaded_executor.config_flow" not in hass.config.components
|
||||
await integration.async_get_component()
|
||||
|
||||
button_module_name = f"{integration.pkg_path}.button"
|
||||
switch_module_name = f"{integration.pkg_path}.switch"
|
||||
light_module_name = f"{integration.pkg_path}.light"
|
||||
button_module_mock = MagicMock()
|
||||
switch_module_mock = MagicMock()
|
||||
light_module_mock = MagicMock()
|
||||
|
||||
def import_module(name: str) -> Any:
|
||||
if name == button_module_name:
|
||||
return button_module_mock
|
||||
if name == switch_module_name:
|
||||
return switch_module_mock
|
||||
if name == light_module_name:
|
||||
return light_module_mock
|
||||
raise ImportError
|
||||
|
||||
modules_without_button = {
|
||||
k: v for k, v in sys.modules.items() if k != button_module_name
|
||||
}
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
modules_without_button,
|
||||
clear=True,
|
||||
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||
module = (await integration.async_get_platforms(["button"]))["button"]
|
||||
|
||||
# The button module is missing so we should load
|
||||
# in the executor
|
||||
assert "executor=['button']" in caplog.text
|
||||
assert "loop=[]" in caplog.text
|
||||
assert module is button_module_mock
|
||||
caplog.clear()
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
**modules_without_button,
|
||||
button_module_name: button_module_mock,
|
||||
},
|
||||
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||
module = (await integration.async_get_platforms(["button"]))["button"]
|
||||
|
||||
# Everything is cached so there should be no logging
|
||||
assert "loop=" not in caplog.text
|
||||
assert "executor=" not in caplog.text
|
||||
assert module is button_module_mock
|
||||
caplog.clear()
|
||||
|
||||
modules_without_switch = {
|
||||
k: v for k, v in sys.modules.items() if k not in switch_module_name
|
||||
}
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{**modules_without_switch, light_module_name: light_module_mock},
|
||||
clear=True,
|
||||
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||
modules = await integration.async_get_platforms(["button", "switch", "light"])
|
||||
|
||||
# The button module is already in the cache so nothing happens
|
||||
# The switch module is loaded in the executor since its not in the cache
|
||||
# The light module is in memory but not in the cache so its loaded in the loop
|
||||
assert "['button']" not in caplog.text
|
||||
assert "executor=['switch']" in caplog.text
|
||||
assert "loop=['light']" in caplog.text
|
||||
assert modules == {
|
||||
"button": button_module_mock,
|
||||
"switch": switch_module_mock,
|
||||
"light": light_module_mock,
|
||||
}
|
||||
|
||||
|
||||
async def test_async_get_platforms_concurrent_loads(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
enable_custom_integrations: None,
|
||||
) -> None:
|
||||
"""Verify async_get_platforms waits if the first load if called again.
|
||||
|
||||
Case is for when when a second load is called
|
||||
and the first is still in progress.
|
||||
"""
|
||||
integration = await loader.async_get_integration(
|
||||
hass, "test_package_loaded_executor"
|
||||
)
|
||||
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
|
||||
assert integration.import_executor is True
|
||||
assert integration.config_flow is True
|
||||
|
||||
assert "test_package_loaded_executor" not in hass.config.components
|
||||
assert "test_package_loaded_executor.config_flow" not in hass.config.components
|
||||
await integration.async_get_component()
|
||||
|
||||
button_module_name = f"{integration.pkg_path}.button"
|
||||
button_module_mock = MagicMock()
|
||||
|
||||
imports = []
|
||||
start_event = threading.Event()
|
||||
import_event = asyncio.Event()
|
||||
|
||||
def import_module(name: str) -> Any:
|
||||
hass.loop.call_soon_threadsafe(import_event.set)
|
||||
imports.append(name)
|
||||
start_event.wait()
|
||||
if name == button_module_name:
|
||||
return button_module_mock
|
||||
raise ImportError
|
||||
|
||||
modules_without_button = {
|
||||
k: v
|
||||
for k, v in sys.modules.items()
|
||||
if k != button_module_name and k != integration.pkg_path
|
||||
}
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
modules_without_button,
|
||||
clear=True,
|
||||
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||
load_task1 = asyncio.create_task(integration.async_get_platforms(["button"]))
|
||||
load_task2 = asyncio.create_task(integration.async_get_platforms(["button"]))
|
||||
await import_event.wait() # make sure the import is started
|
||||
assert not integration._import_futures["button"].done()
|
||||
start_event.set()
|
||||
load_result1 = await load_task1
|
||||
load_result2 = await load_task2
|
||||
assert integration._import_futures is not None
|
||||
|
||||
assert load_result1 == {"button": button_module_mock}
|
||||
assert load_result2 == {"button": button_module_mock}
|
||||
|
||||
assert imports == [button_module_name]
|
||||
|
@ -0,0 +1 @@
|
||||
"""Provide a mock button platform."""
|
@ -0,0 +1 @@
|
||||
"""Provide a mock light platform."""
|
@ -0,0 +1 @@
|
||||
"""Provide a mock switch platform."""
|
Loading…
x
Reference in New Issue
Block a user