mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Add a helper to import modules from the event loop (#113169)
* Add a helper to import modules in the event loop Replaces the one used for triggers with a more generic helper that can be reused and uses a future to avoid importing concurrently * Add a helper to import modules in the event loop Replaces the one used for triggers with a more generic helper that can be reused and uses a future to avoid importing concurrently * coverage * make sure we do not retry * coverage
This commit is contained in:
parent
4f113f256f
commit
cfe14bca8f
@ -1,9 +1,10 @@
|
|||||||
"""Home Assistant trigger dispatcher."""
|
"""Home Assistant trigger dispatcher."""
|
||||||
|
|
||||||
import importlib
|
from typing import cast
|
||||||
|
|
||||||
from homeassistant.const import CONF_PLATFORM
|
from homeassistant.const import CONF_PLATFORM
|
||||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||||
|
from homeassistant.helpers.importlib import async_import_module
|
||||||
from homeassistant.helpers.trigger import (
|
from homeassistant.helpers.trigger import (
|
||||||
TriggerActionType,
|
TriggerActionType,
|
||||||
TriggerInfo,
|
TriggerInfo,
|
||||||
@ -11,26 +12,15 @@ from homeassistant.helpers.trigger import (
|
|||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
DATA_TRIGGER_PLATFORMS = "homeassistant_trigger_platforms"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_trigger_platform(platform_name: str) -> TriggerProtocol:
|
|
||||||
"""Get trigger platform."""
|
|
||||||
return importlib.import_module(f"..triggers.{platform_name}", __name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def _async_get_trigger_platform(
|
async def _async_get_trigger_platform(
|
||||||
hass: HomeAssistant, platform_name: str
|
hass: HomeAssistant, platform_name: str
|
||||||
) -> TriggerProtocol:
|
) -> TriggerProtocol:
|
||||||
"""Get trigger platform from cache or import it."""
|
"""Get trigger platform from cache or import it."""
|
||||||
cache: dict[str, TriggerProtocol] = hass.data.setdefault(DATA_TRIGGER_PLATFORMS, {})
|
platform = await async_import_module(
|
||||||
if platform := cache.get(platform_name):
|
hass, f"homeassistant.components.homeassistant.triggers.{platform_name}"
|
||||||
return platform
|
|
||||||
platform = await hass.async_add_import_executor_job(
|
|
||||||
_get_trigger_platform, platform_name
|
|
||||||
)
|
)
|
||||||
cache[platform_name] = platform
|
return cast(TriggerProtocol, platform)
|
||||||
return platform
|
|
||||||
|
|
||||||
|
|
||||||
async def async_validate_trigger_config(
|
async def async_validate_trigger_config(
|
||||||
|
65
homeassistant/helpers/importlib.py
Normal file
65
homeassistant/helpers/importlib.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
"""Helper to import modules from asyncio."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import suppress
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DATA_IMPORT_CACHE = "import_cache"
|
||||||
|
DATA_IMPORT_FUTURES = "import_futures"
|
||||||
|
DATA_IMPORT_FAILURES = "import_failures"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_module(cache: dict[str, ModuleType], name: str) -> ModuleType:
|
||||||
|
"""Get a module."""
|
||||||
|
cache[name] = importlib.import_module(name)
|
||||||
|
return cache[name]
|
||||||
|
|
||||||
|
|
||||||
|
async def async_import_module(hass: HomeAssistant, name: str) -> ModuleType:
|
||||||
|
"""Import a module or return it from the cache."""
|
||||||
|
cache: dict[str, ModuleType] = hass.data.setdefault(DATA_IMPORT_CACHE, {})
|
||||||
|
if module := cache.get(name):
|
||||||
|
return module
|
||||||
|
|
||||||
|
failure_cache: dict[str, BaseException] = hass.data.setdefault(
|
||||||
|
DATA_IMPORT_FAILURES, {}
|
||||||
|
)
|
||||||
|
if exception := failure_cache.get(name):
|
||||||
|
raise exception
|
||||||
|
|
||||||
|
import_futures: dict[str, asyncio.Future[ModuleType]]
|
||||||
|
import_futures = hass.data.setdefault(DATA_IMPORT_FUTURES, {})
|
||||||
|
|
||||||
|
if future := import_futures.get(name):
|
||||||
|
return await future
|
||||||
|
|
||||||
|
if name in sys.modules:
|
||||||
|
return _get_module(cache, name)
|
||||||
|
|
||||||
|
import_future = hass.loop.create_future()
|
||||||
|
import_futures[name] = import_future
|
||||||
|
try:
|
||||||
|
module = await hass.async_add_import_executor_job(_get_module, cache, name)
|
||||||
|
import_future.set_result(module)
|
||||||
|
except BaseException as ex:
|
||||||
|
failure_cache[name] = ex
|
||||||
|
import_future.set_exception(ex)
|
||||||
|
with suppress(BaseException):
|
||||||
|
# Set the exception retrieved flag on the future since
|
||||||
|
# it will never be retrieved unless there
|
||||||
|
# are concurrent calls
|
||||||
|
import_future.result()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
del import_futures[name]
|
||||||
|
|
||||||
|
return module
|
81
tests/helpers/test_importlib.py
Normal file
81
tests/helpers/test_importlib.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""Tests for the importlib helper."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import importlib
|
||||||
|
|
||||||
|
from tests.common import MockModule
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_import_module(hass: HomeAssistant) -> None:
|
||||||
|
"""Test importing a module."""
|
||||||
|
mock_module = MockModule()
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.importlib.importlib.import_module",
|
||||||
|
return_value=mock_module,
|
||||||
|
):
|
||||||
|
module = await importlib.async_import_module(hass, "test.module")
|
||||||
|
|
||||||
|
assert module is mock_module
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_import_module_on_helper(hass: HomeAssistant) -> None:
|
||||||
|
"""Test importing the importlib helper."""
|
||||||
|
module = await importlib.async_import_module(
|
||||||
|
hass, "homeassistant.helpers.importlib"
|
||||||
|
)
|
||||||
|
assert module is importlib
|
||||||
|
module = await importlib.async_import_module(
|
||||||
|
hass, "homeassistant.helpers.importlib"
|
||||||
|
)
|
||||||
|
assert module is importlib
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_import_module_failures(hass: HomeAssistant) -> None:
|
||||||
|
"""Test importing a module fails."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.importlib.importlib.import_module",
|
||||||
|
side_effect=ImportError,
|
||||||
|
), pytest.raises(ImportError):
|
||||||
|
await importlib.async_import_module(hass, "test.module")
|
||||||
|
|
||||||
|
mock_module = MockModule()
|
||||||
|
# The failure should be cached
|
||||||
|
with pytest.raises(ImportError), patch(
|
||||||
|
"homeassistant.helpers.importlib.importlib.import_module",
|
||||||
|
return_value=mock_module,
|
||||||
|
):
|
||||||
|
await importlib.async_import_module(hass, "test.module")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("eager_start", [True, False])
|
||||||
|
async def test_async_import_module_concurrency(
|
||||||
|
hass: HomeAssistant, eager_start: bool
|
||||||
|
) -> None:
|
||||||
|
"""Test importing a module with concurrency."""
|
||||||
|
mock_module = MockModule()
|
||||||
|
|
||||||
|
def _mock_import(name: str, *args: Any) -> MockModule:
|
||||||
|
time.sleep(0.1)
|
||||||
|
return mock_module
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.importlib.importlib.import_module",
|
||||||
|
_mock_import,
|
||||||
|
):
|
||||||
|
task1 = hass.async_create_task(
|
||||||
|
importlib.async_import_module(hass, "test.module"), eager_start=eager_start
|
||||||
|
)
|
||||||
|
task2 = hass.async_create_task(
|
||||||
|
importlib.async_import_module(hass, "test.module"), eager_start=eager_start
|
||||||
|
)
|
||||||
|
module1 = await task1
|
||||||
|
module2 = await task2
|
||||||
|
|
||||||
|
assert module1 is mock_module
|
||||||
|
assert module2 is mock_module
|
Loading…
x
Reference in New Issue
Block a user