mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Group loading of platforms in the import executor (#112141)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
e0a8a9d551
commit
5227976aa2
@ -506,7 +506,7 @@ class ConfigEntry:
|
|||||||
|
|
||||||
if domain_is_integration:
|
if domain_is_integration:
|
||||||
try:
|
try:
|
||||||
await integration.async_get_platform("config_flow")
|
await integration.async_get_platforms(("config_flow",))
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
(
|
(
|
||||||
@ -1814,6 +1814,8 @@ class ConfigEntries:
|
|||||||
self, entry: ConfigEntry, platforms: Iterable[Platform | str]
|
self, entry: ConfigEntry, platforms: Iterable[Platform | str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Forward the setup of an entry to platforms."""
|
"""Forward the setup of an entry to platforms."""
|
||||||
|
integration = await loader.async_get_integration(self.hass, entry.domain)
|
||||||
|
await integration.async_get_platforms(platforms)
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
create_eager_task(
|
create_eager_task(
|
||||||
@ -2519,7 +2521,7 @@ async def _load_integration(
|
|||||||
# Make sure requirements and dependencies of component are resolved
|
# Make sure requirements and dependencies of component are resolved
|
||||||
await async_process_deps_reqs(hass, hass_config, integration)
|
await async_process_deps_reqs(hass, hass_config, integration)
|
||||||
try:
|
try:
|
||||||
await integration.async_get_platform("config_flow")
|
await integration.async_get_platforms(("config_flow",))
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"Error occurred loading flow for integration %s: %s",
|
"Error occurred loading flow for integration %s: %s",
|
||||||
|
@ -27,6 +27,7 @@ from awesomeversion import (
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from . import generated
|
from . import generated
|
||||||
|
from .const import Platform
|
||||||
from .core import HomeAssistant, callback
|
from .core import HomeAssistant, callback
|
||||||
from .generated.application_credentials import APPLICATION_CREDENTIALS
|
from .generated.application_credentials import APPLICATION_CREDENTIALS
|
||||||
from .generated.bluetooth import BLUETOOTH
|
from .generated.bluetooth import BLUETOOTH
|
||||||
@ -663,6 +664,12 @@ class Integration:
|
|||||||
|
|
||||||
self._component_future: asyncio.Future[ComponentProtocol] | None = None
|
self._component_future: asyncio.Future[ComponentProtocol] | None = None
|
||||||
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
|
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
|
||||||
|
cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS]
|
||||||
|
self._cache = cache
|
||||||
|
missing_platforms_cache: dict[str, ImportError] = hass.data[
|
||||||
|
DATA_MISSING_PLATFORMS
|
||||||
|
]
|
||||||
|
self._missing_platforms_cache = missing_platforms_cache
|
||||||
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
|
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@ -909,12 +916,14 @@ class Integration:
|
|||||||
with a dict cache which is thread-safe since importlib has
|
with a dict cache which is thread-safe since importlib has
|
||||||
appropriate locks.
|
appropriate locks.
|
||||||
"""
|
"""
|
||||||
cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS]
|
cache = self._cache
|
||||||
if self.domain in cache:
|
domain = self.domain
|
||||||
return cache[self.domain]
|
|
||||||
|
if domain in cache:
|
||||||
|
return cache[domain]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cache[self.domain] = cast(
|
cache[domain] = cast(
|
||||||
ComponentProtocol, importlib.import_module(self.pkg_path)
|
ComponentProtocol, importlib.import_module(self.pkg_path)
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -945,75 +954,122 @@ class Integration:
|
|||||||
with suppress(ImportError):
|
with suppress(ImportError):
|
||||||
self.get_platform("config_flow")
|
self.get_platform("config_flow")
|
||||||
|
|
||||||
return cache[self.domain]
|
return cache[domain]
|
||||||
|
|
||||||
|
def _load_platforms(self, platform_names: Iterable[str]) -> dict[str, ModuleType]:
|
||||||
|
"""Load platforms for an integration."""
|
||||||
|
return {
|
||||||
|
platform_name: self._load_platform(platform_name)
|
||||||
|
for platform_name in platform_names
|
||||||
|
}
|
||||||
|
|
||||||
async def async_get_platform(self, platform_name: str) -> ModuleType:
|
async def async_get_platform(self, platform_name: str) -> ModuleType:
|
||||||
"""Return a platform for an integration."""
|
"""Return a platform for an integration."""
|
||||||
|
platforms = await self.async_get_platforms([platform_name])
|
||||||
|
return platforms[platform_name]
|
||||||
|
|
||||||
|
async def async_get_platforms(
|
||||||
|
self, platform_names: Iterable[Platform | str]
|
||||||
|
) -> dict[str, ModuleType]:
|
||||||
|
"""Return a platforms for an integration."""
|
||||||
domain = self.domain
|
domain = self.domain
|
||||||
full_name = f"{self.domain}.{platform_name}"
|
platforms: dict[str, ModuleType] = {}
|
||||||
if platform := self._get_platform_cached(full_name):
|
|
||||||
return platform
|
load_executor_platforms: list[str] = []
|
||||||
if future := self._import_futures.get(full_name):
|
load_event_loop_platforms: list[str] = []
|
||||||
return await future
|
in_progress_imports: dict[str, asyncio.Future[ModuleType]] = {}
|
||||||
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
import_futures: list[tuple[str, asyncio.Future[ModuleType]]] = []
|
||||||
start = time.perf_counter()
|
|
||||||
import_future = self.hass.loop.create_future()
|
for platform_name in platform_names:
|
||||||
self._import_futures[full_name] = import_future
|
full_name = f"{domain}.{platform_name}"
|
||||||
load_executor = (
|
if platform := self._get_platform_cached(full_name):
|
||||||
self.import_executor
|
platforms[platform_name] = platform
|
||||||
and domain not in self.hass.config.components
|
continue
|
||||||
and f"{self.pkg_path}.{domain}" not in sys.modules
|
|
||||||
)
|
# Another call to async_get_platforms is already importing this platform
|
||||||
try:
|
if future := self._import_futures.get(platform_name):
|
||||||
if load_executor:
|
in_progress_imports[platform_name] = future
|
||||||
try:
|
continue
|
||||||
platform = await self.hass.async_add_import_executor_job(
|
|
||||||
self._load_platform, platform_name
|
if (
|
||||||
)
|
self.import_executor
|
||||||
except ImportError as ex:
|
and full_name not in self.hass.config.components
|
||||||
_LOGGER.debug(
|
and f"{self.pkg_path}.{platform_name}" not in sys.modules
|
||||||
"Failed to import %s in executor", domain, exc_info=ex
|
):
|
||||||
)
|
load_executor_platforms.append(platform_name)
|
||||||
load_executor = False
|
|
||||||
# If importing in the executor deadlocks because there is a circular
|
|
||||||
# dependency, we fall back to the event loop.
|
|
||||||
platform = self._load_platform(platform_name)
|
|
||||||
else:
|
else:
|
||||||
platform = self._load_platform(platform_name)
|
load_event_loop_platforms.append(platform_name)
|
||||||
import_future.set_result(platform)
|
|
||||||
except BaseException as ex:
|
|
||||||
import_future.set_exception(ex)
|
|
||||||
with suppress(BaseException):
|
|
||||||
# Clear the exception retrieved flag on the future since
|
|
||||||
# it will never be retrieved unless there
|
|
||||||
# are concurrent calls to async_get_platform
|
|
||||||
import_future.result()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._import_futures.pop(full_name)
|
|
||||||
|
|
||||||
if debug:
|
import_future = self.hass.loop.create_future()
|
||||||
_LOGGER.debug(
|
self._import_futures[platform_name] = import_future
|
||||||
"Importing platform %s took %.2fs (loaded_executor=%s)",
|
import_futures.append((platform_name, import_future))
|
||||||
full_name,
|
|
||||||
time.perf_counter() - start,
|
|
||||||
load_executor,
|
|
||||||
)
|
|
||||||
|
|
||||||
return platform
|
if load_executor_platforms or load_event_loop_platforms:
|
||||||
|
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if load_executor_platforms:
|
||||||
|
try:
|
||||||
|
platforms.update(
|
||||||
|
await self.hass.async_add_import_executor_job(
|
||||||
|
self._load_platforms, platform_names
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except ImportError as ex:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Failed to import %s platforms %s in executor",
|
||||||
|
domain,
|
||||||
|
load_executor_platforms,
|
||||||
|
exc_info=ex,
|
||||||
|
)
|
||||||
|
# If importing in the executor deadlocks because there is a circular
|
||||||
|
# dependency, we fall back to the event loop.
|
||||||
|
load_event_loop_platforms.extend(load_executor_platforms)
|
||||||
|
|
||||||
|
if load_event_loop_platforms:
|
||||||
|
platforms.update(self._load_platforms(platform_names))
|
||||||
|
|
||||||
|
for platform_name, import_future in import_futures:
|
||||||
|
import_future.set_result(platforms[platform_name])
|
||||||
|
|
||||||
|
except BaseException as ex:
|
||||||
|
for _, import_future in import_futures:
|
||||||
|
import_future.set_exception(ex)
|
||||||
|
with suppress(BaseException):
|
||||||
|
# Set the exception retrieved flag on the future since
|
||||||
|
# it will never be retrieved unless there
|
||||||
|
# are concurrent calls to async_get_platforms
|
||||||
|
import_future.result()
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for platform_name, _ in import_futures:
|
||||||
|
self._import_futures.pop(platform_name)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Importing platforms for %s executor=%s loop=%s took %.2fs",
|
||||||
|
domain,
|
||||||
|
load_executor_platforms,
|
||||||
|
load_event_loop_platforms,
|
||||||
|
time.perf_counter() - start,
|
||||||
|
)
|
||||||
|
|
||||||
|
if in_progress_imports:
|
||||||
|
for platform_name, future in in_progress_imports.items():
|
||||||
|
platforms[platform_name] = await future
|
||||||
|
|
||||||
|
return platforms
|
||||||
|
|
||||||
def _get_platform_cached(self, full_name: str) -> ModuleType | None:
|
def _get_platform_cached(self, full_name: str) -> ModuleType | None:
|
||||||
"""Return a platform for an integration from cache."""
|
"""Return a platform for an integration from cache."""
|
||||||
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
|
if full_name in self._cache:
|
||||||
if full_name in cache:
|
# the cache is either a ModuleType or a ComponentProtocol
|
||||||
return cache[full_name]
|
# but we only care about the ModuleType here
|
||||||
|
return self._cache[full_name] # type: ignore[return-value]
|
||||||
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
|
if full_name in self._missing_platforms_cache:
|
||||||
DATA_MISSING_PLATFORMS
|
raise self._missing_platforms_cache[full_name]
|
||||||
]
|
|
||||||
if full_name in missing_platforms_cache:
|
|
||||||
raise missing_platforms_cache[full_name]
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_platform(self, platform_name: str) -> ModuleType:
|
def get_platform(self, platform_name: str) -> ModuleType:
|
||||||
@ -1033,14 +1089,11 @@ class Integration:
|
|||||||
has been called for the integration or the integration failed to load.
|
has been called for the integration or the integration failed to load.
|
||||||
"""
|
"""
|
||||||
full_name = f"{self.domain}.{platform_name}"
|
full_name = f"{self.domain}.{platform_name}"
|
||||||
|
cache = self._cache
|
||||||
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
|
|
||||||
if full_name in cache:
|
if full_name in cache:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
missing_platforms_cache: dict[str, ImportError]
|
if full_name in self._missing_platforms_cache:
|
||||||
missing_platforms_cache = self.hass.data[DATA_MISSING_PLATFORMS]
|
|
||||||
if full_name in missing_platforms_cache:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not (component := cache.get(self.domain)) or not (
|
if not (component := cache.get(self.domain)) or not (
|
||||||
@ -1056,7 +1109,7 @@ class Integration:
|
|||||||
f"Platform {full_name} not found",
|
f"Platform {full_name} not found",
|
||||||
name=f"{self.pkg_path}.{platform_name}",
|
name=f"{self.pkg_path}.{platform_name}",
|
||||||
)
|
)
|
||||||
missing_platforms_cache[full_name] = exc
|
self._missing_platforms_cache[full_name] = exc
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _load_platform(self, platform_name: str) -> ModuleType:
|
def _load_platform(self, platform_name: str) -> ModuleType:
|
||||||
@ -1077,10 +1130,7 @@ class Integration:
|
|||||||
if self.domain in cache:
|
if self.domain in cache:
|
||||||
# If the domain is loaded, cache that the platform
|
# If the domain is loaded, cache that the platform
|
||||||
# does not exist so we do not try to load it again
|
# does not exist so we do not try to load it again
|
||||||
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
|
self._missing_platforms_cache[full_name] = ex
|
||||||
DATA_MISSING_PLATFORMS
|
|
||||||
]
|
|
||||||
missing_platforms_cache[full_name] = ex
|
|
||||||
raise
|
raise
|
||||||
except RuntimeError as err:
|
except RuntimeError as err:
|
||||||
# _DeadlockError inherits from RuntimeError
|
# _DeadlockError inherits from RuntimeError
|
||||||
|
@ -278,6 +278,41 @@ async def test_async_get_platform_caches_failures_when_component_loaded(
|
|||||||
assert await integration.async_get_platform("light") == hue_light
|
assert await integration.async_get_platform("light") == hue_light
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_platforms_caches_failures_when_component_loaded(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_platforms cache failures only when the component is loaded."""
|
||||||
|
integration = await loader.async_get_integration(hass, "hue")
|
||||||
|
|
||||||
|
with pytest.raises(ImportError), patch(
|
||||||
|
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||||
|
):
|
||||||
|
assert integration.get_component() == hue
|
||||||
|
|
||||||
|
with pytest.raises(ImportError), patch(
|
||||||
|
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||||
|
):
|
||||||
|
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||||
|
|
||||||
|
# Hue is not loaded so we should still hit the import_module path
|
||||||
|
with pytest.raises(ImportError), patch(
|
||||||
|
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||||
|
):
|
||||||
|
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||||
|
|
||||||
|
assert integration.get_component() == hue
|
||||||
|
|
||||||
|
# Hue is loaded so we should cache the import_module failure now
|
||||||
|
with pytest.raises(ImportError), patch(
|
||||||
|
"homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom")
|
||||||
|
):
|
||||||
|
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||||
|
|
||||||
|
# Hue is loaded and the last call should have cached the import_module failure
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||||
|
|
||||||
|
|
||||||
async def test_get_integration_legacy(
|
async def test_get_integration_legacy(
|
||||||
hass: HomeAssistant, enable_custom_integrations: None
|
hass: HomeAssistant, enable_custom_integrations: None
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -1302,7 +1337,9 @@ async def test_async_get_platform_deadlock_fallback(
|
|||||||
"Detected deadlock trying to import homeassistant.components.executor_import"
|
"Detected deadlock trying to import homeassistant.components.executor_import"
|
||||||
in caplog.text
|
in caplog.text
|
||||||
)
|
)
|
||||||
assert "loaded_executor=False" in caplog.text
|
# We should have tried both the executor and loop
|
||||||
|
assert "executor=['config_flow']" in caplog.text
|
||||||
|
assert "loop=['config_flow']" in caplog.text
|
||||||
assert module is module_mock
|
assert module is module_mock
|
||||||
|
|
||||||
|
|
||||||
@ -1390,3 +1427,155 @@ async def test_platform_exists(
|
|||||||
assert platform.MAGIC == 1
|
assert platform.MAGIC == 1
|
||||||
|
|
||||||
assert integration.platform_exists("group") is True
|
assert integration.platform_exists("group") is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_platforms_loads_loop_if_already_in_sys_modules(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
enable_custom_integrations: None,
|
||||||
|
) -> None:
|
||||||
|
"""Verify async_get_platforms does not create an executor job.
|
||||||
|
|
||||||
|
Case is for when the module is already in sys.modules.
|
||||||
|
"""
|
||||||
|
integration = await loader.async_get_integration(
|
||||||
|
hass, "test_package_loaded_executor"
|
||||||
|
)
|
||||||
|
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
|
||||||
|
assert integration.import_executor is True
|
||||||
|
assert integration.config_flow is True
|
||||||
|
|
||||||
|
assert "test_package_loaded_executor" not in hass.config.components
|
||||||
|
assert "test_package_loaded_executor.config_flow" not in hass.config.components
|
||||||
|
await integration.async_get_component()
|
||||||
|
|
||||||
|
button_module_name = f"{integration.pkg_path}.button"
|
||||||
|
switch_module_name = f"{integration.pkg_path}.switch"
|
||||||
|
light_module_name = f"{integration.pkg_path}.light"
|
||||||
|
button_module_mock = MagicMock()
|
||||||
|
switch_module_mock = MagicMock()
|
||||||
|
light_module_mock = MagicMock()
|
||||||
|
|
||||||
|
def import_module(name: str) -> Any:
|
||||||
|
if name == button_module_name:
|
||||||
|
return button_module_mock
|
||||||
|
if name == switch_module_name:
|
||||||
|
return switch_module_mock
|
||||||
|
if name == light_module_name:
|
||||||
|
return light_module_mock
|
||||||
|
raise ImportError
|
||||||
|
|
||||||
|
modules_without_button = {
|
||||||
|
k: v for k, v in sys.modules.items() if k != button_module_name
|
||||||
|
}
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
modules_without_button,
|
||||||
|
clear=True,
|
||||||
|
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||||
|
module = (await integration.async_get_platforms(["button"]))["button"]
|
||||||
|
|
||||||
|
# The button module is missing so we should load
|
||||||
|
# in the executor
|
||||||
|
assert "executor=['button']" in caplog.text
|
||||||
|
assert "loop=[]" in caplog.text
|
||||||
|
assert module is button_module_mock
|
||||||
|
caplog.clear()
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
**modules_without_button,
|
||||||
|
button_module_name: button_module_mock,
|
||||||
|
},
|
||||||
|
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||||
|
module = (await integration.async_get_platforms(["button"]))["button"]
|
||||||
|
|
||||||
|
# Everything is cached so there should be no logging
|
||||||
|
assert "loop=" not in caplog.text
|
||||||
|
assert "executor=" not in caplog.text
|
||||||
|
assert module is button_module_mock
|
||||||
|
caplog.clear()
|
||||||
|
|
||||||
|
modules_without_switch = {
|
||||||
|
k: v for k, v in sys.modules.items() if k not in switch_module_name
|
||||||
|
}
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{**modules_without_switch, light_module_name: light_module_mock},
|
||||||
|
clear=True,
|
||||||
|
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||||
|
modules = await integration.async_get_platforms(["button", "switch", "light"])
|
||||||
|
|
||||||
|
# The button module is already in the cache so nothing happens
|
||||||
|
# The switch module is loaded in the executor since its not in the cache
|
||||||
|
# The light module is in memory but not in the cache so its loaded in the loop
|
||||||
|
assert "['button']" not in caplog.text
|
||||||
|
assert "executor=['switch']" in caplog.text
|
||||||
|
assert "loop=['light']" in caplog.text
|
||||||
|
assert modules == {
|
||||||
|
"button": button_module_mock,
|
||||||
|
"switch": switch_module_mock,
|
||||||
|
"light": light_module_mock,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_platforms_concurrent_loads(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
enable_custom_integrations: None,
|
||||||
|
) -> None:
|
||||||
|
"""Verify async_get_platforms waits if the first load if called again.
|
||||||
|
|
||||||
|
Case is for when when a second load is called
|
||||||
|
and the first is still in progress.
|
||||||
|
"""
|
||||||
|
integration = await loader.async_get_integration(
|
||||||
|
hass, "test_package_loaded_executor"
|
||||||
|
)
|
||||||
|
assert integration.pkg_path == "custom_components.test_package_loaded_executor"
|
||||||
|
assert integration.import_executor is True
|
||||||
|
assert integration.config_flow is True
|
||||||
|
|
||||||
|
assert "test_package_loaded_executor" not in hass.config.components
|
||||||
|
assert "test_package_loaded_executor.config_flow" not in hass.config.components
|
||||||
|
await integration.async_get_component()
|
||||||
|
|
||||||
|
button_module_name = f"{integration.pkg_path}.button"
|
||||||
|
button_module_mock = MagicMock()
|
||||||
|
|
||||||
|
imports = []
|
||||||
|
start_event = threading.Event()
|
||||||
|
import_event = asyncio.Event()
|
||||||
|
|
||||||
|
def import_module(name: str) -> Any:
|
||||||
|
hass.loop.call_soon_threadsafe(import_event.set)
|
||||||
|
imports.append(name)
|
||||||
|
start_event.wait()
|
||||||
|
if name == button_module_name:
|
||||||
|
return button_module_mock
|
||||||
|
raise ImportError
|
||||||
|
|
||||||
|
modules_without_button = {
|
||||||
|
k: v
|
||||||
|
for k, v in sys.modules.items()
|
||||||
|
if k != button_module_name and k != integration.pkg_path
|
||||||
|
}
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
modules_without_button,
|
||||||
|
clear=True,
|
||||||
|
), patch("homeassistant.loader.importlib.import_module", import_module):
|
||||||
|
load_task1 = asyncio.create_task(integration.async_get_platforms(["button"]))
|
||||||
|
load_task2 = asyncio.create_task(integration.async_get_platforms(["button"]))
|
||||||
|
await import_event.wait() # make sure the import is started
|
||||||
|
assert not integration._import_futures["button"].done()
|
||||||
|
start_event.set()
|
||||||
|
load_result1 = await load_task1
|
||||||
|
load_result2 = await load_task2
|
||||||
|
assert integration._import_futures is not None
|
||||||
|
|
||||||
|
assert load_result1 == {"button": button_module_mock}
|
||||||
|
assert load_result2 == {"button": button_module_mock}
|
||||||
|
|
||||||
|
assert imports == [button_module_name]
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
"""Provide a mock button platform."""
|
@ -0,0 +1 @@
|
|||||||
|
"""Provide a mock light platform."""
|
@ -0,0 +1 @@
|
|||||||
|
"""Provide a mock switch platform."""
|
Loading…
x
Reference in New Issue
Block a user