From 5227976aa24bfb8dbe024f073e0356aef6a475c7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 3 Mar 2024 21:32:19 -1000 Subject: [PATCH] Group loading of platforms in the import executor (#112141) Co-authored-by: Paulus Schoutsen --- homeassistant/config_entries.py | 6 +- homeassistant/loader.py | 196 +++++++++++------- tests/test_loader.py | 191 ++++++++++++++++- .../test_package_loaded_executor/button.py | 1 + .../test_package_loaded_executor/light.py | 1 + .../test_package_loaded_executor/switch.py | 1 + 6 files changed, 320 insertions(+), 76 deletions(-) create mode 100644 tests/testing_config/custom_components/test_package_loaded_executor/button.py create mode 100644 tests/testing_config/custom_components/test_package_loaded_executor/light.py create mode 100644 tests/testing_config/custom_components/test_package_loaded_executor/switch.py diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index b1a4a8ce6cc..219f4ff1709 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -506,7 +506,7 @@ class ConfigEntry: if domain_is_integration: try: - await integration.async_get_platform("config_flow") + await integration.async_get_platforms(("config_flow",)) except ImportError as err: _LOGGER.error( ( @@ -1814,6 +1814,8 @@ class ConfigEntries: self, entry: ConfigEntry, platforms: Iterable[Platform | str] ) -> None: """Forward the setup of an entry to platforms.""" + integration = await loader.async_get_integration(self.hass, entry.domain) + await integration.async_get_platforms(platforms) await asyncio.gather( *( create_eager_task( @@ -2519,7 +2521,7 @@ async def _load_integration( # Make sure requirements and dependencies of component are resolved await async_process_deps_reqs(hass, hass_config, integration) try: - await integration.async_get_platform("config_flow") + await integration.async_get_platforms(("config_flow",)) except ImportError as err: _LOGGER.error( "Error occurred loading flow for integration %s: %s", diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 33ccd5615a8..f19577ac10a 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -27,6 +27,7 @@ from awesomeversion import ( import voluptuous as vol from . import generated +from .const import Platform from .core import HomeAssistant, callback from .generated.application_credentials import APPLICATION_CREDENTIALS from .generated.bluetooth import BLUETOOTH @@ -663,6 +664,12 @@ class Integration: self._component_future: asyncio.Future[ComponentProtocol] | None = None self._import_futures: dict[str, asyncio.Future[ModuleType]] = {} + cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS] + self._cache = cache + missing_platforms_cache: dict[str, ImportError] = hass.data[ + DATA_MISSING_PLATFORMS + ] + self._missing_platforms_cache = missing_platforms_cache _LOGGER.info("Loaded %s from %s", self.domain, pkg_path) @cached_property @@ -909,12 +916,14 @@ class Integration: with a dict cache which is thread-safe since importlib has appropriate locks. """ - cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS] - if self.domain in cache: - return cache[self.domain] + cache = self._cache + domain = self.domain + + if domain in cache: + return cache[domain] try: - cache[self.domain] = cast( + cache[domain] = cast( ComponentProtocol, importlib.import_module(self.pkg_path) ) except ImportError: @@ -945,75 +954,122 @@ class Integration: with suppress(ImportError): self.get_platform("config_flow") - return cache[self.domain] + return cache[domain] + + def _load_platforms(self, platform_names: Iterable[str]) -> dict[str, ModuleType]: + """Load platforms for an integration.""" + return { + platform_name: self._load_platform(platform_name) + for platform_name in platform_names + } async def async_get_platform(self, platform_name: str) -> ModuleType: """Return a platform for an integration.""" + platforms = await self.async_get_platforms([platform_name]) + return platforms[platform_name] + + async def async_get_platforms( + self, platform_names: Iterable[Platform | str] + ) -> dict[str, ModuleType]: + """Return a platforms for an integration.""" domain = self.domain - full_name = f"{self.domain}.{platform_name}" - if platform := self._get_platform_cached(full_name): - return platform - if future := self._import_futures.get(full_name): - return await future - if debug := _LOGGER.isEnabledFor(logging.DEBUG): - start = time.perf_counter() - import_future = self.hass.loop.create_future() - self._import_futures[full_name] = import_future - load_executor = ( - self.import_executor - and domain not in self.hass.config.components - and f"{self.pkg_path}.{domain}" not in sys.modules - ) - try: - if load_executor: - try: - platform = await self.hass.async_add_import_executor_job( - self._load_platform, platform_name - ) - except ImportError as ex: - _LOGGER.debug( - "Failed to import %s in executor", domain, exc_info=ex - ) - load_executor = False - # If importing in the executor deadlocks because there is a circular - # dependency, we fall back to the event loop. - platform = self._load_platform(platform_name) + platforms: dict[str, ModuleType] = {} + + load_executor_platforms: list[str] = [] + load_event_loop_platforms: list[str] = [] + in_progress_imports: dict[str, asyncio.Future[ModuleType]] = {} + import_futures: list[tuple[str, asyncio.Future[ModuleType]]] = [] + + for platform_name in platform_names: + full_name = f"{domain}.{platform_name}" + if platform := self._get_platform_cached(full_name): + platforms[platform_name] = platform + continue + + # Another call to async_get_platforms is already importing this platform + if future := self._import_futures.get(platform_name): + in_progress_imports[platform_name] = future + continue + + if ( + self.import_executor + and full_name not in self.hass.config.components + and f"{self.pkg_path}.{platform_name}" not in sys.modules + ): + load_executor_platforms.append(platform_name) else: - platform = self._load_platform(platform_name) - import_future.set_result(platform) - except BaseException as ex: - import_future.set_exception(ex) - with suppress(BaseException): - # Clear the exception retrieved flag on the future since - # it will never be retrieved unless there - # are concurrent calls to async_get_platform - import_future.result() - raise - finally: - self._import_futures.pop(full_name) + load_event_loop_platforms.append(platform_name) - if debug: - _LOGGER.debug( - "Importing platform %s took %.2fs (loaded_executor=%s)", - full_name, - time.perf_counter() - start, - load_executor, - ) + import_future = self.hass.loop.create_future() + self._import_futures[platform_name] = import_future + import_futures.append((platform_name, import_future)) - return platform + if load_executor_platforms or load_event_loop_platforms: + if debug := _LOGGER.isEnabledFor(logging.DEBUG): + start = time.perf_counter() + + try: + if load_executor_platforms: + try: + platforms.update( + await self.hass.async_add_import_executor_job( + self._load_platforms, platform_names + ) + ) + except ImportError as ex: + _LOGGER.debug( + "Failed to import %s platforms %s in executor", + domain, + load_executor_platforms, + exc_info=ex, + ) + # If importing in the executor deadlocks because there is a circular + # dependency, we fall back to the event loop. + load_event_loop_platforms.extend(load_executor_platforms) + + if load_event_loop_platforms: + platforms.update(self._load_platforms(platform_names)) + + for platform_name, import_future in import_futures: + import_future.set_result(platforms[platform_name]) + + except BaseException as ex: + for _, import_future in import_futures: + import_future.set_exception(ex) + with suppress(BaseException): + # Set the exception retrieved flag on the future since + # it will never be retrieved unless there + # are concurrent calls to async_get_platforms + import_future.result() + raise + + finally: + for platform_name, _ in import_futures: + self._import_futures.pop(platform_name) + + if debug: + _LOGGER.debug( + "Importing platforms for %s executor=%s loop=%s took %.2fs", + domain, + load_executor_platforms, + load_event_loop_platforms, + time.perf_counter() - start, + ) + + if in_progress_imports: + for platform_name, future in in_progress_imports.items(): + platforms[platform_name] = await future + + return platforms def _get_platform_cached(self, full_name: str) -> ModuleType | None: """Return a platform for an integration from cache.""" - cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS] - if full_name in cache: - return cache[full_name] - - missing_platforms_cache: dict[str, ImportError] = self.hass.data[ - DATA_MISSING_PLATFORMS - ] - if full_name in missing_platforms_cache: - raise missing_platforms_cache[full_name] - + if full_name in self._cache: + # the cache is either a ModuleType or a ComponentProtocol + # but we only care about the ModuleType here + return self._cache[full_name] # type: ignore[return-value] + if full_name in self._missing_platforms_cache: + raise self._missing_platforms_cache[full_name] return None def get_platform(self, platform_name: str) -> ModuleType: @@ -1033,14 +1089,11 @@ class Integration: has been called for the integration or the integration failed to load. """ full_name = f"{self.domain}.{platform_name}" - - cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS] + cache = self._cache if full_name in cache: return True - missing_platforms_cache: dict[str, ImportError] - missing_platforms_cache = self.hass.data[DATA_MISSING_PLATFORMS] - if full_name in missing_platforms_cache: + if full_name in self._missing_platforms_cache: return False if not (component := cache.get(self.domain)) or not ( @@ -1056,7 +1109,7 @@ class Integration: f"Platform {full_name} not found", name=f"{self.pkg_path}.{platform_name}", ) - missing_platforms_cache[full_name] = exc + self._missing_platforms_cache[full_name] = exc return False def _load_platform(self, platform_name: str) -> ModuleType: @@ -1077,10 +1130,7 @@ class Integration: if self.domain in cache: # If the domain is loaded, cache that the platform # does not exist so we do not try to load it again - missing_platforms_cache: dict[str, ImportError] = self.hass.data[ - DATA_MISSING_PLATFORMS - ] - missing_platforms_cache[full_name] = ex + self._missing_platforms_cache[full_name] = ex raise except RuntimeError as err: # _DeadlockError inherits from RuntimeError diff --git a/tests/test_loader.py b/tests/test_loader.py index 8400adca5c4..fdbc457dfe0 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -278,6 +278,41 @@ async def test_async_get_platform_caches_failures_when_component_loaded( assert await integration.async_get_platform("light") == hue_light +async def test_async_get_platforms_caches_failures_when_component_loaded( + hass: HomeAssistant, +) -> None: + """Test async_get_platforms cache failures only when the component is loaded.""" + integration = await loader.async_get_integration(hass, "hue") + + with pytest.raises(ImportError), patch( + "homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom") + ): + assert integration.get_component() == hue + + with pytest.raises(ImportError), patch( + "homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom") + ): + assert await integration.async_get_platforms(["light"]) == {"light": hue_light} + + # Hue is not loaded so we should still hit the import_module path + with pytest.raises(ImportError), patch( + "homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom") + ): + assert await integration.async_get_platforms(["light"]) == {"light": hue_light} + + assert integration.get_component() == hue + + # Hue is loaded so we should cache the import_module failure now + with pytest.raises(ImportError), patch( + "homeassistant.loader.importlib.import_module", side_effect=ImportError("Boom") + ): + assert await integration.async_get_platforms(["light"]) == {"light": hue_light} + + # Hue is loaded and the last call should have cached the import_module failure + with pytest.raises(ImportError): + assert await integration.async_get_platforms(["light"]) == {"light": hue_light} + + async def test_get_integration_legacy( hass: HomeAssistant, enable_custom_integrations: None ) -> None: @@ -1302,7 +1337,9 @@ async def test_async_get_platform_deadlock_fallback( "Detected deadlock trying to import homeassistant.components.executor_import" in caplog.text ) - assert "loaded_executor=False" in caplog.text + # We should have tried both the executor and loop + assert "executor=['config_flow']" in caplog.text + assert "loop=['config_flow']" in caplog.text assert module is module_mock @@ -1390,3 +1427,155 @@ async def test_platform_exists( assert platform.MAGIC == 1 assert integration.platform_exists("group") is True + + +async def test_async_get_platforms_loads_loop_if_already_in_sys_modules( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + enable_custom_integrations: None, +) -> None: + """Verify async_get_platforms does not create an executor job. + + Case is for when the module is already in sys.modules. + """ + integration = await loader.async_get_integration( + hass, "test_package_loaded_executor" + ) + assert integration.pkg_path == "custom_components.test_package_loaded_executor" + assert integration.import_executor is True + assert integration.config_flow is True + + assert "test_package_loaded_executor" not in hass.config.components + assert "test_package_loaded_executor.config_flow" not in hass.config.components + await integration.async_get_component() + + button_module_name = f"{integration.pkg_path}.button" + switch_module_name = f"{integration.pkg_path}.switch" + light_module_name = f"{integration.pkg_path}.light" + button_module_mock = MagicMock() + switch_module_mock = MagicMock() + light_module_mock = MagicMock() + + def import_module(name: str) -> Any: + if name == button_module_name: + return button_module_mock + if name == switch_module_name: + return switch_module_mock + if name == light_module_name: + return light_module_mock + raise ImportError + + modules_without_button = { + k: v for k, v in sys.modules.items() if k != button_module_name + } + with patch.dict( + "sys.modules", + modules_without_button, + clear=True, + ), patch("homeassistant.loader.importlib.import_module", import_module): + module = (await integration.async_get_platforms(["button"]))["button"] + + # The button module is missing so we should load + # in the executor + assert "executor=['button']" in caplog.text + assert "loop=[]" in caplog.text + assert module is button_module_mock + caplog.clear() + + with patch.dict( + "sys.modules", + { + **modules_without_button, + button_module_name: button_module_mock, + }, + ), patch("homeassistant.loader.importlib.import_module", import_module): + module = (await integration.async_get_platforms(["button"]))["button"] + + # Everything is cached so there should be no logging + assert "loop=" not in caplog.text + assert "executor=" not in caplog.text + assert module is button_module_mock + caplog.clear() + + modules_without_switch = { + k: v for k, v in sys.modules.items() if k not in switch_module_name + } + with patch.dict( + "sys.modules", + {**modules_without_switch, light_module_name: light_module_mock}, + clear=True, + ), patch("homeassistant.loader.importlib.import_module", import_module): + modules = await integration.async_get_platforms(["button", "switch", "light"]) + + # The button module is already in the cache so nothing happens + # The switch module is loaded in the executor since its not in the cache + # The light module is in memory but not in the cache so its loaded in the loop + assert "['button']" not in caplog.text + assert "executor=['switch']" in caplog.text + assert "loop=['light']" in caplog.text + assert modules == { + "button": button_module_mock, + "switch": switch_module_mock, + "light": light_module_mock, + } + + +async def test_async_get_platforms_concurrent_loads( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + enable_custom_integrations: None, +) -> None: + """Verify async_get_platforms waits if the first load if called again. + + Case is for when when a second load is called + and the first is still in progress. + """ + integration = await loader.async_get_integration( + hass, "test_package_loaded_executor" + ) + assert integration.pkg_path == "custom_components.test_package_loaded_executor" + assert integration.import_executor is True + assert integration.config_flow is True + + assert "test_package_loaded_executor" not in hass.config.components + assert "test_package_loaded_executor.config_flow" not in hass.config.components + await integration.async_get_component() + + button_module_name = f"{integration.pkg_path}.button" + button_module_mock = MagicMock() + + imports = [] + start_event = threading.Event() + import_event = asyncio.Event() + + def import_module(name: str) -> Any: + hass.loop.call_soon_threadsafe(import_event.set) + imports.append(name) + start_event.wait() + if name == button_module_name: + return button_module_mock + raise ImportError + + modules_without_button = { + k: v + for k, v in sys.modules.items() + if k != button_module_name and k != integration.pkg_path + } + with patch.dict( + "sys.modules", + modules_without_button, + clear=True, + ), patch("homeassistant.loader.importlib.import_module", import_module): + load_task1 = asyncio.create_task(integration.async_get_platforms(["button"])) + load_task2 = asyncio.create_task(integration.async_get_platforms(["button"])) + await import_event.wait() # make sure the import is started + assert not integration._import_futures["button"].done() + start_event.set() + load_result1 = await load_task1 + load_result2 = await load_task2 + assert integration._import_futures is not None + + assert load_result1 == {"button": button_module_mock} + assert load_result2 == {"button": button_module_mock} + + assert imports == [button_module_name] diff --git a/tests/testing_config/custom_components/test_package_loaded_executor/button.py b/tests/testing_config/custom_components/test_package_loaded_executor/button.py new file mode 100644 index 00000000000..0157551af84 --- /dev/null +++ b/tests/testing_config/custom_components/test_package_loaded_executor/button.py @@ -0,0 +1 @@ +"""Provide a mock button platform.""" diff --git a/tests/testing_config/custom_components/test_package_loaded_executor/light.py b/tests/testing_config/custom_components/test_package_loaded_executor/light.py new file mode 100644 index 00000000000..0f1e5f1a631 --- /dev/null +++ b/tests/testing_config/custom_components/test_package_loaded_executor/light.py @@ -0,0 +1 @@ +"""Provide a mock light platform.""" diff --git a/tests/testing_config/custom_components/test_package_loaded_executor/switch.py b/tests/testing_config/custom_components/test_package_loaded_executor/switch.py new file mode 100644 index 00000000000..134235622f3 --- /dev/null +++ b/tests/testing_config/custom_components/test_package_loaded_executor/switch.py @@ -0,0 +1 @@ +"""Provide a mock switch platform."""