From cfe14bca8f590d5b26f6dce05760b72949370e54 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 13 Mar 2024 18:26:33 -1000 Subject: [PATCH] Add a helper to import modules from the event loop (#113169) * Add a helper to import modules in the event loop Replaces the one used for triggers with a more generic helper that can be reused and uses a future to avoid importing concurrently * Add a helper to import modules in the event loop Replaces the one used for triggers with a more generic helper that can be reused and uses a future to avoid importing concurrently * coverage * make sure we do not retry * coverage --- .../components/homeassistant/trigger.py | 20 ++--- homeassistant/helpers/importlib.py | 65 +++++++++++++++ tests/helpers/test_importlib.py | 81 +++++++++++++++++++ 3 files changed, 151 insertions(+), 15 deletions(-) create mode 100644 homeassistant/helpers/importlib.py create mode 100644 tests/helpers/test_importlib.py diff --git a/homeassistant/components/homeassistant/trigger.py b/homeassistant/components/homeassistant/trigger.py index 74a96fce784..495cd07502a 100644 --- a/homeassistant/components/homeassistant/trigger.py +++ b/homeassistant/components/homeassistant/trigger.py @@ -1,9 +1,10 @@ """Home Assistant trigger dispatcher.""" -import importlib +from typing import cast from homeassistant.const import CONF_PLATFORM from homeassistant.core import CALLBACK_TYPE, HomeAssistant +from homeassistant.helpers.importlib import async_import_module from homeassistant.helpers.trigger import ( TriggerActionType, TriggerInfo, @@ -11,26 +12,15 @@ from homeassistant.helpers.trigger import ( ) from homeassistant.helpers.typing import ConfigType -DATA_TRIGGER_PLATFORMS = "homeassistant_trigger_platforms" - - -def _get_trigger_platform(platform_name: str) -> TriggerProtocol: - """Get trigger platform.""" - return importlib.import_module(f"..triggers.{platform_name}", __name__) - async def _async_get_trigger_platform( hass: HomeAssistant, platform_name: str ) -> TriggerProtocol: """Get trigger platform from cache or import it.""" - cache: dict[str, TriggerProtocol] = hass.data.setdefault(DATA_TRIGGER_PLATFORMS, {}) - if platform := cache.get(platform_name): - return platform - platform = await hass.async_add_import_executor_job( - _get_trigger_platform, platform_name + platform = await async_import_module( + hass, f"homeassistant.components.homeassistant.triggers.{platform_name}" ) - cache[platform_name] = platform - return platform + return cast(TriggerProtocol, platform) async def async_validate_trigger_config( diff --git a/homeassistant/helpers/importlib.py b/homeassistant/helpers/importlib.py new file mode 100644 index 00000000000..00af75f6d8e --- /dev/null +++ b/homeassistant/helpers/importlib.py @@ -0,0 +1,65 @@ +"""Helper to import modules from asyncio.""" + +from __future__ import annotations + +import asyncio +from contextlib import suppress +import importlib +import logging +import sys +from types import ModuleType + +from homeassistant.core import HomeAssistant + +_LOGGER = logging.getLogger(__name__) + +DATA_IMPORT_CACHE = "import_cache" +DATA_IMPORT_FUTURES = "import_futures" +DATA_IMPORT_FAILURES = "import_failures" + + +def _get_module(cache: dict[str, ModuleType], name: str) -> ModuleType: + """Get a module.""" + cache[name] = importlib.import_module(name) + return cache[name] + + +async def async_import_module(hass: HomeAssistant, name: str) -> ModuleType: + """Import a module or return it from the cache.""" + cache: dict[str, ModuleType] = hass.data.setdefault(DATA_IMPORT_CACHE, {}) + if module := cache.get(name): + return module + + failure_cache: dict[str, BaseException] = hass.data.setdefault( + DATA_IMPORT_FAILURES, {} + ) + if exception := failure_cache.get(name): + raise exception + + import_futures: dict[str, asyncio.Future[ModuleType]] + import_futures = hass.data.setdefault(DATA_IMPORT_FUTURES, {}) + + if future := import_futures.get(name): + return await future + + if name in sys.modules: + return _get_module(cache, name) + + import_future = hass.loop.create_future() + import_futures[name] = import_future + try: + module = await hass.async_add_import_executor_job(_get_module, cache, name) + import_future.set_result(module) + except BaseException as ex: + failure_cache[name] = ex + import_future.set_exception(ex) + with suppress(BaseException): + # Set the exception retrieved flag on the future since + # it will never be retrieved unless there + # are concurrent calls + import_future.result() + raise + finally: + del import_futures[name] + + return module diff --git a/tests/helpers/test_importlib.py b/tests/helpers/test_importlib.py new file mode 100644 index 00000000000..7f89018ded2 --- /dev/null +++ b/tests/helpers/test_importlib.py @@ -0,0 +1,81 @@ +"""Tests for the importlib helper.""" + +import time +from typing import Any +from unittest.mock import patch + +import pytest + +from homeassistant.core import HomeAssistant +from homeassistant.helpers import importlib + +from tests.common import MockModule + + +async def test_async_import_module(hass: HomeAssistant) -> None: + """Test importing a module.""" + mock_module = MockModule() + with patch( + "homeassistant.helpers.importlib.importlib.import_module", + return_value=mock_module, + ): + module = await importlib.async_import_module(hass, "test.module") + + assert module is mock_module + + +async def test_async_import_module_on_helper(hass: HomeAssistant) -> None: + """Test importing the importlib helper.""" + module = await importlib.async_import_module( + hass, "homeassistant.helpers.importlib" + ) + assert module is importlib + module = await importlib.async_import_module( + hass, "homeassistant.helpers.importlib" + ) + assert module is importlib + + +async def test_async_import_module_failures(hass: HomeAssistant) -> None: + """Test importing a module fails.""" + with patch( + "homeassistant.helpers.importlib.importlib.import_module", + side_effect=ImportError, + ), pytest.raises(ImportError): + await importlib.async_import_module(hass, "test.module") + + mock_module = MockModule() + # The failure should be cached + with pytest.raises(ImportError), patch( + "homeassistant.helpers.importlib.importlib.import_module", + return_value=mock_module, + ): + await importlib.async_import_module(hass, "test.module") + + +@pytest.mark.parametrize("eager_start", [True, False]) +async def test_async_import_module_concurrency( + hass: HomeAssistant, eager_start: bool +) -> None: + """Test importing a module with concurrency.""" + mock_module = MockModule() + + def _mock_import(name: str, *args: Any) -> MockModule: + time.sleep(0.1) + return mock_module + + with patch( + "homeassistant.helpers.importlib.importlib.import_module", + _mock_import, + ): + task1 = hass.async_create_task( + importlib.async_import_module(hass, "test.module"), eager_start=eager_start + ) + task2 = hass.async_create_task( + importlib.async_import_module(hass, "test.module"), eager_start=eager_start + ) + module1 = await task1 + module2 = await task2 + + assert module1 is mock_module + assert module2 is mock_module