Group loading of platforms in the import executor (#112141)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
J. Nick Koston 2024-03-03 21:32:19 -10:00 committed by GitHub
parent e0a8a9d551
commit 5227976aa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 320 additions and 76 deletions

View File

@ -506,7 +506,7 @@ class ConfigEntry:
if domain_is_integration:
try:
await integration.async_get_platform("config_flow")
await integration.async_get_platforms(("config_flow",))
except ImportError as err:
_LOGGER.error(
(
@ -1814,6 +1814,8 @@ class ConfigEntries:
self, entry: ConfigEntry, platforms: Iterable[Platform | str]
) -> None:
"""Forward the setup of an entry to platforms."""
integration = await loader.async_get_integration(self.hass, entry.domain)
await integration.async_get_platforms(platforms)
await asyncio.gather(
*(
create_eager_task(
@ -2519,7 +2521,7 @@ async def _load_integration(
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(hass, hass_config, integration)
try:
await integration.async_get_platform("config_flow")
await integration.async_get_platforms(("config_flow",))
except ImportError as err:
_LOGGER.error(
"Error occurred loading flow for integration %s: %s",

View File

@ -27,6 +27,7 @@ from awesomeversion import (
import voluptuous as vol
from . import generated
from .const import Platform
from .core import HomeAssistant, callback
from .generated.application_credentials import APPLICATION_CREDENTIALS
from .generated.bluetooth import BLUETOOTH
@ -663,6 +664,12 @@ class Integration:
self._component_future: asyncio.Future[ComponentProtocol] | None = None
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS]
self._cache = cache
missing_platforms_cache: dict[str, ImportError] = hass.data[
DATA_MISSING_PLATFORMS
]
self._missing_platforms_cache = missing_platforms_cache
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
@cached_property
@ -909,12 +916,14 @@ class Integration:
with a dict cache which is thread-safe since importlib has
appropriate locks.
"""
cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS]
if self.domain in cache:
return cache[self.domain]
cache = self._cache
domain = self.domain
if domain in cache:
return cache[domain]
try:
cache[self.domain] = cast(
cache[domain] = cast(
ComponentProtocol, importlib.import_module(self.pkg_path)
)
except ImportError:
@ -945,75 +954,122 @@ class Integration:
with suppress(ImportError):
self.get_platform("config_flow")
return cache[self.domain]
return cache[domain]
def _load_platforms(self, platform_names: Iterable[str]) -> dict[str, ModuleType]:
"""Load platforms for an integration."""
return {
platform_name: self._load_platform(platform_name)
for platform_name in platform_names
}
async def async_get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration."""
platforms = await self.async_get_platforms([platform_name])
return platforms[platform_name]
async def async_get_platforms(
self, platform_names: Iterable[Platform | str]
) -> dict[str, ModuleType]:
"""Return a platforms for an integration."""
domain = self.domain
full_name = f"{self.domain}.{platform_name}"
if platform := self._get_platform_cached(full_name):
return platform
if future := self._import_futures.get(full_name):
return await future
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
start = time.perf_counter()
import_future = self.hass.loop.create_future()
self._import_futures[full_name] = import_future
load_executor = (
self.import_executor
and domain not in self.hass.config.components
and f"{self.pkg_path}.{domain}" not in sys.modules
)
try:
if load_executor:
try:
platform = await self.hass.async_add_import_executor_job(
self._load_platform, platform_name
)
except ImportError as ex:
_LOGGER.debug(
"Failed to import %s in executor", domain, exc_info=ex
)
load_executor = False
# If importing in the executor deadlocks because there is a circular
# dependency, we fall back to the event loop.
platform = self._load_platform(platform_name)
platforms: dict[str, ModuleType] = {}
load_executor_platforms: list[str] = []
load_event_loop_platforms: list[str] = []
in_progress_imports: dict[str, asyncio.Future[ModuleType]] = {}
import_futures: list[tuple[str, asyncio.Future[ModuleType]]] = []
for platform_name in platform_names:
full_name = f"{domain}.{platform_name}"
if platform := self._get_platform_cached(full_name):
platforms[platform_name] = platform
continue
# Another call to async_get_platforms is already importing this platform
if future := self._import_futures.get(platform_name):
in_progress_imports[platform_name] = future
continue
if (
self.import_executor
and full_name not in self.hass.config.components
and f"{self.pkg_path}.{platform_name}" not in sys.modules
):
load_executor_platforms.append(platform_name)
else:
platform = self._load_platform(platform_name)
import_future.set_result(platform)
except BaseException as ex:
import_future.set_exception(ex)
with suppress(BaseException):
# Clear the exception retrieved flag on the future since
# it will never be retrieved unless there
# are concurrent calls to async_get_platform
import_future.result()
raise
finally:
self._import_futures.pop(full_name)
load_event_loop_platforms.append(platform_name)
if debug:
_LOGGER.debug(
"Importing platform %s took %.2fs (loaded_executor=%s)",
full_name,
time.perf_counter() - start,
load_executor,
)
import_future = self.hass.loop.create_future()
self._import_futures[platform_name] = import_future
import_futures.append((platform_name, import_future))
return platform
if load_executor_platforms or load_event_loop_platforms:
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
start = time.perf_counter()
try:
if load_executor_platforms:
try:
platforms.update(
await self.hass.async_add_import_executor_job(
self._load_platforms, platform_names
)
)
except ImportError as ex:
_LOGGER.debug(
"Failed to import %s platforms %s in executor",
domain,
load_executor_platforms,
exc_info=ex,
)
# If importing in the executor deadlocks because there is a circular
# dependency, we fall back to the event loop.
load_event_loop_platforms.extend(load_executor_platforms)
if load_event_loop_platforms:
platforms.update(self._load_platforms(platform_names))
for platform_name, import_future in import_futures:
import_future.set_result(platforms[platform_name])
except BaseException as ex:
for _, import_future in import_futures:
import_future.set_exception(ex)
with suppress(BaseException):
# Set the exception retrieved flag on the future since
# it will never be retrieved unless there
# are concurrent calls to async_get_platforms
import_future.result()
raise
finally:
for platform_name, _ in import_futures:
self._import_futures.pop(platform_name)
if debug:
_LOGGER.debug(
"Importing platforms for %s executor=%s loop=%s took %.2fs",
domain,
load_executor_platforms,
load_event_loop_platforms,
time.perf_counter() - start,
)
if in_progress_imports:
for platform_name, future in in_progress_imports.items():
platforms[platform_name] = await future
return platforms
def _get_platform_cached(self, full_name: str) -> ModuleType | None:
"""Return a platform for an integration from cache."""
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
if full_name in cache:
return cache[full_name]
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
DATA_MISSING_PLATFORMS
]
if full_name in missing_platforms_cache:
raise missing_platforms_cache[full_name]
if full_name in self._cache:
# the cache is either a ModuleType or a ComponentProtocol
# but we only care about the ModuleType here
return self._cache[full_name] # type: ignore[return-value]
if full_name in self._missing_platforms_cache:
raise self._missing_platforms_cache[full_name]
return None
def get_platform(self, platform_name: str) -> ModuleType:
@ -1033,14 +1089,11 @@ class Integration:
has been called for the integration or the integration failed to load.
"""
full_name = f"{self.domain}.{platform_name}"
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
cache = self._cache
if full_name in cache:
return True
missing_platforms_cache: dict[str, ImportError]
missing_platforms_cache = self.hass.data[DATA_MISSING_PLATFORMS]
if full_name in missing_platforms_cache:
if full_name in self._missing_platforms_cache:
return False
if not (component := cache.get(self.domain)) or not (
@ -1056,7 +1109,7 @@ class Integration:
f"Platform {full_name} not found",
name=f"{self.pkg_path}.{platform_name}",
)
missing_platforms_cache[full_name] = exc
self._missing_platforms_cache[full_name] = exc
return False
def _load_platform(self, platform_name: str) -> ModuleType:
@ -1077,10 +1130,7 @@ class Integration:
if self.domain in cache:
# If the domain is loaded, cache that the platform
# does not exist so we do not try to load it again
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
DATA_MISSING_PLATFORMS
]
missing_platforms_cache[full_name] = ex
self._missing_platforms_cache[full_name] = ex
raise
except RuntimeError as err:
# _DeadlockError inherits from RuntimeError

View File

@ -278,6 +278,41 @@ async def test_async_get_platform_caches_failures_when_component_loaded(
assert await integration.async_get_platform("light") == hue_light
async def test_async_get_platforms_caches_failures_when_component_loaded(
hass: HomeAssistant,
) -> None:
"""Test async_get_platforms cache failures only when the component is loaded."""
integration = await loader.async_get_integration(hass, "hue")
with pytest.raises(ImportError), patch(
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
):
assert integration.get_component() == hue
with pytest.raises(ImportError), patch(
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
# Hue is not loaded so we should still hit the import_module path
with pytest.raises(ImportError), patch(
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
assert integration.get_component() == hue
# Hue is loaded so we should cache the import_module failure now
with pytest.raises(ImportError), patch(
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
# Hue is loaded and the last call should have cached the import_module failure
with pytest.raises(ImportError):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
async def test_get_integration_legacy(
hass: HomeAssistant, enable_custom_integrations: None
) -> None:
@ -1302,7 +1337,9 @@ async def test_async_get_platform_deadlock_fallback(
"Detected deadlock trying to import homeassistant.components.executor_import"
in caplog.text
)
assert "loaded_executor=False" in caplog.text
# We should have tried both the executor and loop
assert "executor=['config_flow']" in caplog.text
assert "loop=['config_flow']" in caplog.text
assert module is module_mock
@ -1390,3 +1427,155 @@ async def test_platform_exists(
assert platform.MAGIC == 1
assert integration.platform_exists("group") is True
async def test_async_get_platforms_loads_loop_if_already_in_sys_modules(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
enable_custom_integrations: None,
) -> None:
"""Verify async_get_platforms does not create an executor job.
Case is for when the module is already in sys.modules.
"""
integration = await loader.async_get_integration(
hass, "test_package_loaded_executor"
)
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
assert integration.import_executor is True
assert integration.config_flow is True
assert "test_package_loaded_executor" not in hass.config.components
assert "test_package_loaded_executor.config_flow" not in hass.config.components
await integration.async_get_component()
button_module_name = f"{integration.pkg_path}.button"
switch_module_name = f"{integration.pkg_path}.switch"
light_module_name = f"{integration.pkg_path}.light"
button_module_mock = MagicMock()
switch_module_mock = MagicMock()
light_module_mock = MagicMock()
def import_module(name: str) -> Any:
if name == button_module_name:
return button_module_mock
if name == switch_module_name:
return switch_module_mock
if name == light_module_name:
return light_module_mock
raise ImportError
modules_without_button = {
k: v for k, v in sys.modules.items() if k != button_module_name
}
with patch.dict(
"sys.modules",
modules_without_button,
clear=True,
), patch("homeassistant.loader.importlib.import_module", import_module):
module = (await integration.async_get_platforms(["button"]))["button"]
# The button module is missing so we should load
# in the executor
assert "executor=['button']" in caplog.text
assert "loop=[]" in caplog.text
assert module is button_module_mock
caplog.clear()
with patch.dict(
"sys.modules",
{
**modules_without_button,
button_module_name: button_module_mock,
},
), patch("homeassistant.loader.importlib.import_module", import_module):
module = (await integration.async_get_platforms(["button"]))["button"]
# Everything is cached so there should be no logging
assert "loop=" not in caplog.text
assert "executor=" not in caplog.text
assert module is button_module_mock
caplog.clear()
modules_without_switch = {
k: v for k, v in sys.modules.items() if k not in switch_module_name
}
with patch.dict(
"sys.modules",
{**modules_without_switch, light_module_name: light_module_mock},
clear=True,
), patch("homeassistant.loader.importlib.import_module", import_module):
modules = await integration.async_get_platforms(["button", "switch", "light"])
# The button module is already in the cache so nothing happens
# The switch module is loaded in the executor since its not in the cache
# The light module is in memory but not in the cache so its loaded in the loop
assert "['button']" not in caplog.text
assert "executor=['switch']" in caplog.text
assert "loop=['light']" in caplog.text
assert modules == {
"button": button_module_mock,
"switch": switch_module_mock,
"light": light_module_mock,
}
async def test_async_get_platforms_concurrent_loads(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
enable_custom_integrations: None,
) -> None:
"""Verify async_get_platforms waits if the first load if called again.
Case is for when when a second load is called
and the first is still in progress.
"""
integration = await loader.async_get_integration(
hass, "test_package_loaded_executor"
)
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
assert integration.import_executor is True
assert integration.config_flow is True
assert "test_package_loaded_executor" not in hass.config.components
assert "test_package_loaded_executor.config_flow" not in hass.config.components
await integration.async_get_component()
button_module_name = f"{integration.pkg_path}.button"
button_module_mock = MagicMock()
imports = []
start_event = threading.Event()
import_event = asyncio.Event()
def import_module(name: str) -> Any:
hass.loop.call_soon_threadsafe(import_event.set)
imports.append(name)
start_event.wait()
if name == button_module_name:
return button_module_mock
raise ImportError
modules_without_button = {
k: v
for k, v in sys.modules.items()
if k != button_module_name and k != integration.pkg_path
}
with patch.dict(
"sys.modules",
modules_without_button,
clear=True,
), patch("homeassistant.loader.importlib.import_module", import_module):
load_task1 = asyncio.create_task(integration.async_get_platforms(["button"]))
load_task2 = asyncio.create_task(integration.async_get_platforms(["button"]))
await import_event.wait() # make sure the import is started
assert not integration._import_futures["button"].done()
start_event.set()
load_result1 = await load_task1
load_result2 = await load_task2
assert integration._import_futures is not None
assert load_result1 == {"button": button_module_mock}
assert load_result2 == {"button": button_module_mock}
assert imports == [button_module_name]

View File

@ -0,0 +1 @@
"""Provide a mock button platform."""

View File

@ -0,0 +1 @@
"""Provide a mock light platform."""

View File

@ -0,0 +1 @@
"""Provide a mock switch platform."""