Add a helper to import modules from the event loop (#113169)

* Add a helper to import modules in the event loop

Replaces the one used for triggers with a more generic helper
that can be reused and uses a future to avoid importing concurrently

* Add a helper to import modules in the event loop

Replaces the one used for triggers with a more generic helper
that can be reused and uses a future to avoid importing concurrently

* coverage

* make sure we do not retry

* coverage
This commit is contained in:
J. Nick Koston 2024-03-13 18:26:33 -10:00 committed by GitHub
parent 4f113f256f
commit cfe14bca8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 151 additions and 15 deletions

View File

@ -1,9 +1,10 @@
"""Home Assistant trigger dispatcher."""
import importlib
from typing import cast
from homeassistant.const import CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.importlib import async_import_module
from homeassistant.helpers.trigger import (
TriggerActionType,
TriggerInfo,
@ -11,26 +12,15 @@ from homeassistant.helpers.trigger import (
)
from homeassistant.helpers.typing import ConfigType
DATA_TRIGGER_PLATFORMS = "homeassistant_trigger_platforms"
def _get_trigger_platform(platform_name: str) -> TriggerProtocol:
"""Get trigger platform."""
return importlib.import_module(f"..triggers.{platform_name}", __name__)
async def _async_get_trigger_platform(
hass: HomeAssistant, platform_name: str
) -> TriggerProtocol:
"""Get trigger platform from cache or import it."""
cache: dict[str, TriggerProtocol] = hass.data.setdefault(DATA_TRIGGER_PLATFORMS, {})
if platform := cache.get(platform_name):
return platform
platform = await hass.async_add_import_executor_job(
_get_trigger_platform, platform_name
platform = await async_import_module(
hass, f"homeassistant.components.homeassistant.triggers.{platform_name}"
)
cache[platform_name] = platform
return platform
return cast(TriggerProtocol, platform)
async def async_validate_trigger_config(

View File

@ -0,0 +1,65 @@
"""Helper to import modules from asyncio."""
from __future__ import annotations
import asyncio
from contextlib import suppress
import importlib
import logging
import sys
from types import ModuleType
from homeassistant.core import HomeAssistant
_LOGGER = logging.getLogger(__name__)
DATA_IMPORT_CACHE = "import_cache"
DATA_IMPORT_FUTURES = "import_futures"
DATA_IMPORT_FAILURES = "import_failures"
def _get_module(cache: dict[str, ModuleType], name: str) -> ModuleType:
"""Get a module."""
cache[name] = importlib.import_module(name)
return cache[name]
async def async_import_module(hass: HomeAssistant, name: str) -> ModuleType:
"""Import a module or return it from the cache."""
cache: dict[str, ModuleType] = hass.data.setdefault(DATA_IMPORT_CACHE, {})
if module := cache.get(name):
return module
failure_cache: dict[str, BaseException] = hass.data.setdefault(
DATA_IMPORT_FAILURES, {}
)
if exception := failure_cache.get(name):
raise exception
import_futures: dict[str, asyncio.Future[ModuleType]]
import_futures = hass.data.setdefault(DATA_IMPORT_FUTURES, {})
if future := import_futures.get(name):
return await future
if name in sys.modules:
return _get_module(cache, name)
import_future = hass.loop.create_future()
import_futures[name] = import_future
try:
module = await hass.async_add_import_executor_job(_get_module, cache, name)
import_future.set_result(module)
except BaseException as ex:
failure_cache[name] = ex
import_future.set_exception(ex)
with suppress(BaseException):
# Set the exception retrieved flag on the future since
# it will never be retrieved unless there
# are concurrent calls
import_future.result()
raise
finally:
del import_futures[name]
return module

View File

@ -0,0 +1,81 @@
"""Tests for the importlib helper."""
import time
from typing import Any
from unittest.mock import patch
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.helpers import importlib
from tests.common import MockModule
async def test_async_import_module(hass: HomeAssistant) -> None:
"""Test importing a module."""
mock_module = MockModule()
with patch(
"homeassistant.helpers.importlib.importlib.import_module",
return_value=mock_module,
):
module = await importlib.async_import_module(hass, "test.module")
assert module is mock_module
async def test_async_import_module_on_helper(hass: HomeAssistant) -> None:
"""Test importing the importlib helper."""
module = await importlib.async_import_module(
hass, "homeassistant.helpers.importlib"
)
assert module is importlib
module = await importlib.async_import_module(
hass, "homeassistant.helpers.importlib"
)
assert module is importlib
async def test_async_import_module_failures(hass: HomeAssistant) -> None:
"""Test importing a module fails."""
with patch(
"homeassistant.helpers.importlib.importlib.import_module",
side_effect=ImportError,
), pytest.raises(ImportError):
await importlib.async_import_module(hass, "test.module")
mock_module = MockModule()
# The failure should be cached
with pytest.raises(ImportError), patch(
"homeassistant.helpers.importlib.importlib.import_module",
return_value=mock_module,
):
await importlib.async_import_module(hass, "test.module")
@pytest.mark.parametrize("eager_start", [True, False])
async def test_async_import_module_concurrency(
hass: HomeAssistant, eager_start: bool
) -> None:
"""Test importing a module with concurrency."""
mock_module = MockModule()
def _mock_import(name: str, *args: Any) -> MockModule:
time.sleep(0.1)
return mock_module
with patch(
"homeassistant.helpers.importlib.importlib.import_module",
_mock_import,
):
task1 = hass.async_create_task(
importlib.async_import_module(hass, "test.module"), eager_start=eager_start
)
task2 = hass.async_create_task(
importlib.async_import_module(hass, "test.module"), eager_start=eager_start
)
module1 = await task1
module2 = await task2
assert module1 is mock_module
assert module2 is mock_module