mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Add a helper to import modules from the event loop (#113169)
* Add a helper to import modules in the event loop Replaces the one used for triggers with a more generic helper that can be reused and uses a future to avoid importing concurrently * Add a helper to import modules in the event loop Replaces the one used for triggers with a more generic helper that can be reused and uses a future to avoid importing concurrently * coverage * make sure we do not retry * coverage
This commit is contained in:
parent
4f113f256f
commit
cfe14bca8f
@ -1,9 +1,10 @@
|
||||
"""Home Assistant trigger dispatcher."""
|
||||
|
||||
import importlib
|
||||
from typing import cast
|
||||
|
||||
from homeassistant.const import CONF_PLATFORM
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||
from homeassistant.helpers.importlib import async_import_module
|
||||
from homeassistant.helpers.trigger import (
|
||||
TriggerActionType,
|
||||
TriggerInfo,
|
||||
@ -11,26 +12,15 @@ from homeassistant.helpers.trigger import (
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
DATA_TRIGGER_PLATFORMS = "homeassistant_trigger_platforms"
|
||||
|
||||
|
||||
def _get_trigger_platform(platform_name: str) -> TriggerProtocol:
|
||||
"""Get trigger platform."""
|
||||
return importlib.import_module(f"..triggers.{platform_name}", __name__)
|
||||
|
||||
|
||||
async def _async_get_trigger_platform(
|
||||
hass: HomeAssistant, platform_name: str
|
||||
) -> TriggerProtocol:
|
||||
"""Get trigger platform from cache or import it."""
|
||||
cache: dict[str, TriggerProtocol] = hass.data.setdefault(DATA_TRIGGER_PLATFORMS, {})
|
||||
if platform := cache.get(platform_name):
|
||||
return platform
|
||||
platform = await hass.async_add_import_executor_job(
|
||||
_get_trigger_platform, platform_name
|
||||
platform = await async_import_module(
|
||||
hass, f"homeassistant.components.homeassistant.triggers.{platform_name}"
|
||||
)
|
||||
cache[platform_name] = platform
|
||||
return platform
|
||||
return cast(TriggerProtocol, platform)
|
||||
|
||||
|
||||
async def async_validate_trigger_config(
|
||||
|
65
homeassistant/helpers/importlib.py
Normal file
65
homeassistant/helpers/importlib.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Helper to import modules from asyncio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DATA_IMPORT_CACHE = "import_cache"
|
||||
DATA_IMPORT_FUTURES = "import_futures"
|
||||
DATA_IMPORT_FAILURES = "import_failures"
|
||||
|
||||
|
||||
def _get_module(cache: dict[str, ModuleType], name: str) -> ModuleType:
|
||||
"""Get a module."""
|
||||
cache[name] = importlib.import_module(name)
|
||||
return cache[name]
|
||||
|
||||
|
||||
async def async_import_module(hass: HomeAssistant, name: str) -> ModuleType:
|
||||
"""Import a module or return it from the cache."""
|
||||
cache: dict[str, ModuleType] = hass.data.setdefault(DATA_IMPORT_CACHE, {})
|
||||
if module := cache.get(name):
|
||||
return module
|
||||
|
||||
failure_cache: dict[str, BaseException] = hass.data.setdefault(
|
||||
DATA_IMPORT_FAILURES, {}
|
||||
)
|
||||
if exception := failure_cache.get(name):
|
||||
raise exception
|
||||
|
||||
import_futures: dict[str, asyncio.Future[ModuleType]]
|
||||
import_futures = hass.data.setdefault(DATA_IMPORT_FUTURES, {})
|
||||
|
||||
if future := import_futures.get(name):
|
||||
return await future
|
||||
|
||||
if name in sys.modules:
|
||||
return _get_module(cache, name)
|
||||
|
||||
import_future = hass.loop.create_future()
|
||||
import_futures[name] = import_future
|
||||
try:
|
||||
module = await hass.async_add_import_executor_job(_get_module, cache, name)
|
||||
import_future.set_result(module)
|
||||
except BaseException as ex:
|
||||
failure_cache[name] = ex
|
||||
import_future.set_exception(ex)
|
||||
with suppress(BaseException):
|
||||
# Set the exception retrieved flag on the future since
|
||||
# it will never be retrieved unless there
|
||||
# are concurrent calls
|
||||
import_future.result()
|
||||
raise
|
||||
finally:
|
||||
del import_futures[name]
|
||||
|
||||
return module
|
81
tests/helpers/test_importlib.py
Normal file
81
tests/helpers/test_importlib.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Tests for the importlib helper."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import importlib
|
||||
|
||||
from tests.common import MockModule
|
||||
|
||||
|
||||
async def test_async_import_module(hass: HomeAssistant) -> None:
|
||||
"""Test importing a module."""
|
||||
mock_module = MockModule()
|
||||
with patch(
|
||||
"homeassistant.helpers.importlib.importlib.import_module",
|
||||
return_value=mock_module,
|
||||
):
|
||||
module = await importlib.async_import_module(hass, "test.module")
|
||||
|
||||
assert module is mock_module
|
||||
|
||||
|
||||
async def test_async_import_module_on_helper(hass: HomeAssistant) -> None:
|
||||
"""Test importing the importlib helper."""
|
||||
module = await importlib.async_import_module(
|
||||
hass, "homeassistant.helpers.importlib"
|
||||
)
|
||||
assert module is importlib
|
||||
module = await importlib.async_import_module(
|
||||
hass, "homeassistant.helpers.importlib"
|
||||
)
|
||||
assert module is importlib
|
||||
|
||||
|
||||
async def test_async_import_module_failures(hass: HomeAssistant) -> None:
|
||||
"""Test importing a module fails."""
|
||||
with patch(
|
||||
"homeassistant.helpers.importlib.importlib.import_module",
|
||||
side_effect=ImportError,
|
||||
), pytest.raises(ImportError):
|
||||
await importlib.async_import_module(hass, "test.module")
|
||||
|
||||
mock_module = MockModule()
|
||||
# The failure should be cached
|
||||
with pytest.raises(ImportError), patch(
|
||||
"homeassistant.helpers.importlib.importlib.import_module",
|
||||
return_value=mock_module,
|
||||
):
|
||||
await importlib.async_import_module(hass, "test.module")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("eager_start", [True, False])
|
||||
async def test_async_import_module_concurrency(
|
||||
hass: HomeAssistant, eager_start: bool
|
||||
) -> None:
|
||||
"""Test importing a module with concurrency."""
|
||||
mock_module = MockModule()
|
||||
|
||||
def _mock_import(name: str, *args: Any) -> MockModule:
|
||||
time.sleep(0.1)
|
||||
return mock_module
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.importlib.importlib.import_module",
|
||||
_mock_import,
|
||||
):
|
||||
task1 = hass.async_create_task(
|
||||
importlib.async_import_module(hass, "test.module"), eager_start=eager_start
|
||||
)
|
||||
task2 = hass.async_create_task(
|
||||
importlib.async_import_module(hass, "test.module"), eager_start=eager_start
|
||||
)
|
||||
module1 = await task1
|
||||
module2 = await task2
|
||||
|
||||
assert module1 is mock_module
|
||||
assert module2 is mock_module
|
Loading…
x
Reference in New Issue
Block a user