diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 5d19249e37b..b8ec5987142 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -35,7 +35,6 @@ from .setup import ( async_setup_component, ) from .util import dt as dt_util -from .util.async_ import gather_with_concurrency from .util.logging import async_activate_log_queue_handler from .util.package import async_get_user_site, is_virtual_env @@ -479,14 +478,9 @@ async def _async_set_up_integrations( integrations_to_process = [ int_or_exc - for int_or_exc in await gather_with_concurrency( - loader.MAX_LOAD_CONCURRENTLY, - *( - loader.async_get_integration(hass, domain) - for domain in old_to_resolve - ), - return_exceptions=True, - ) + for int_or_exc in ( + await loader.async_get_integrations(hass, old_to_resolve) + ).values() if isinstance(int_or_exc, loader.Integration) ] resolve_dependencies_tasks = [ diff --git a/homeassistant/components/analytics/analytics.py b/homeassistant/components/analytics/analytics.py index 5bb0368b021..1a696b0c206 100644 --- a/homeassistant/components/analytics/analytics.py +++ b/homeassistant/components/analytics/analytics.py @@ -18,7 +18,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.storage import Store from homeassistant.helpers.system_info import async_get_system_info -from homeassistant.loader import IntegrationNotFound, async_get_integration +from homeassistant.loader import IntegrationNotFound, async_get_integrations from homeassistant.setup import async_get_loaded_integrations from .const import ( @@ -182,15 +182,9 @@ class Analytics: if self.preferences.get(ATTR_USAGE, False) or self.preferences.get( ATTR_STATISTICS, False ): - configured_integrations = await asyncio.gather( - *( - async_get_integration(self.hass, domain) - for domain in async_get_loaded_integrations(self.hass) - ), - return_exceptions=True, - ) - - for integration in configured_integrations: + domains = async_get_loaded_integrations(self.hass) + configured_integrations = await async_get_integrations(self.hass, domains) + for integration in configured_integrations.values(): if isinstance(integration, IntegrationNotFound): continue diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index b7e7a353633..6c18fd96627 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -1,7 +1,6 @@ """Commands part of Websocket API.""" from __future__ import annotations -import asyncio from collections.abc import Callable import datetime as dt import json @@ -32,7 +31,12 @@ from homeassistant.helpers.event import ( ) from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder from homeassistant.helpers.service import async_get_all_descriptions -from homeassistant.loader import IntegrationNotFound, async_get_integration +from homeassistant.loader import ( + Integration, + IntegrationNotFound, + async_get_integration, + async_get_integrations, +) from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations from homeassistant.util.json import ( find_paths_unserializable_data, @@ -372,9 +376,13 @@ async def handle_manifest_list( wanted_integrations = msg.get("integrations") if wanted_integrations is None: wanted_integrations = async_get_loaded_integrations(hass) - integrations = await asyncio.gather( - *(async_get_integration(hass, domain) for domain in wanted_integrations) - ) + + ints_or_excs = await async_get_integrations(hass, wanted_integrations) + integrations: list[Integration] = [] + for int_or_exc in ints_or_excs.values(): + if isinstance(int_or_exc, Exception): + raise int_or_exc + integrations.append(int_or_exc) connection.send_result( msg["id"], [integration.manifest for integration in integrations] ) @@ -706,12 +714,12 @@ async def handle_supported_brands( ) -> None: """Handle supported brands command.""" data = {} - for integration in await asyncio.gather( - *[ - async_get_integration(hass, integration) - for integration in supported_brands.HAS_SUPPORTED_BRANDS - ] - ): - data[integration.domain] = integration.manifest["supported_brands"] + ints_or_excs = await async_get_integrations( + hass, supported_brands.HAS_SUPPORTED_BRANDS + ) + for int_or_exc in ints_or_excs.values(): + if isinstance(int_or_exc, Exception): + raise int_or_exc + data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"] connection.send_result(msg["id"], data) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index bc3451c24c0..cf7fb3b2304 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -31,13 +31,7 @@ from homeassistant.exceptions import ( Unauthorized, UnknownUser, ) -from homeassistant.loader import ( - MAX_LOAD_CONCURRENTLY, - Integration, - async_get_integration, - bind_hass, -) -from homeassistant.util.async_ import gather_with_concurrency +from homeassistant.loader import Integration, async_get_integrations, bind_hass from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml.loader import JSON_TYPE @@ -467,10 +461,12 @@ async def async_get_all_descriptions( loaded = {} if missing: - integrations = await gather_with_concurrency( - MAX_LOAD_CONCURRENTLY, - *(async_get_integration(hass, domain) for domain in missing), - ) + ints_or_excs = await async_get_integrations(hass, missing) + integrations = [ + int_or_exc + for int_or_exc in ints_or_excs.values() + if isinstance(int_or_exc, Integration) + ] contents = await hass.async_add_executor_job( _load_services_files, hass, integrations diff --git a/homeassistant/helpers/translation.py b/homeassistant/helpers/translation.py index cda50de535b..616baeeea92 100644 --- a/homeassistant/helpers/translation.py +++ b/homeassistant/helpers/translation.py @@ -9,13 +9,11 @@ from typing import Any from homeassistant.core import HomeAssistant, callback from homeassistant.loader import ( - MAX_LOAD_CONCURRENTLY, Integration, async_get_config_flows, - async_get_integration, + async_get_integrations, bind_hass, ) -from homeassistant.util.async_ import gather_with_concurrency from homeassistant.util.json import load_json _LOGGER = logging.getLogger(__name__) @@ -151,16 +149,13 @@ async def async_get_component_strings( ) -> dict[str, Any]: """Load translations.""" domains = list({loaded.split(".")[-1] for loaded in components}) - integrations = dict( - zip( - domains, - await gather_with_concurrency( - MAX_LOAD_CONCURRENTLY, - *(async_get_integration(hass, domain) for domain in domains), - ), - ) - ) + integrations: dict[str, Integration] = {} + ints_or_excs = await async_get_integrations(hass, domains) + for domain, int_or_exc in ints_or_excs.items(): + if isinstance(int_or_exc, Exception): + raise int_or_exc + integrations[domain] = int_or_exc translations: dict[str, Any] = {} # Determine paths of missing components/platforms diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 0a65928701b..d0e6189ef96 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -7,7 +7,7 @@ documentation as possible to keep it understandable. from __future__ import annotations import asyncio -from collections.abc import Callable +from collections.abc import Callable, Iterable from contextlib import suppress import functools as ft import importlib @@ -31,7 +31,6 @@ from .generated.ssdp import SSDP from .generated.usb import USB from .generated.zeroconf import HOMEKIT, ZEROCONF from .helpers.json import JSON_DECODE_EXCEPTIONS, json_loads -from .util.async_ import gather_with_concurrency # Typing imports that create a circular dependency if TYPE_CHECKING: @@ -128,6 +127,7 @@ class Manifest(TypedDict, total=False): version: str codeowners: list[str] loggers: list[str] + supported_brands: dict[str, str] def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest: @@ -166,19 +166,15 @@ async def _async_get_custom_components( get_sub_directories, custom_components.__path__ ) - integrations = await gather_with_concurrency( - MAX_LOAD_CONCURRENTLY, - *( - hass.async_add_executor_job( - Integration.resolve_from_root, hass, custom_components, comp.name - ) - for comp in dirs - ), + integrations = await hass.async_add_executor_job( + _resolve_integrations_from_root, + hass, + custom_components, + [comp.name for comp in dirs], ) - return { integration.domain: integration - for integration in integrations + for integration in integrations.values() if integration is not None } @@ -681,59 +677,101 @@ class Integration: return f"" +def _resolve_integrations_from_root( + hass: HomeAssistant, root_module: ModuleType, domains: list[str] +) -> dict[str, Integration]: + """Resolve multiple integrations from root.""" + integrations: dict[str, Integration] = {} + for domain in domains: + try: + integration = Integration.resolve_from_root(hass, root_module, domain) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error loading integration: %s", domain) + else: + if integration: + integrations[domain] = integration + return integrations + + async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration: - """Get an integration.""" + """Get integration.""" + integrations_or_excs = await async_get_integrations(hass, [domain]) + int_or_exc = integrations_or_excs[domain] + if isinstance(int_or_exc, Integration): + return int_or_exc + raise int_or_exc + + +async def async_get_integrations( + hass: HomeAssistant, domains: Iterable[str] +) -> dict[str, Integration | Exception]: + """Get integrations.""" if (cache := hass.data.get(DATA_INTEGRATIONS)) is None: if not _async_mount_config_dir(hass): - raise IntegrationNotFound(domain) + return {domain: IntegrationNotFound(domain) for domain in domains} cache = hass.data[DATA_INTEGRATIONS] = {} - int_or_evt: Integration | asyncio.Event | None = cache.get(domain, _UNDEF) + results: dict[str, Integration | Exception] = {} + needed: dict[str, asyncio.Event] = {} + in_progress: dict[str, asyncio.Event] = {} + for domain in domains: + int_or_evt: Integration | asyncio.Event | None = cache.get(domain, _UNDEF) + if isinstance(int_or_evt, asyncio.Event): + in_progress[domain] = int_or_evt + elif int_or_evt is not _UNDEF: + results[domain] = cast(Integration, int_or_evt) + elif "." in domain: + results[domain] = ValueError(f"Invalid domain {domain}") + else: + needed[domain] = cache[domain] = asyncio.Event() - if isinstance(int_or_evt, asyncio.Event): - await int_or_evt.wait() + if in_progress: + await asyncio.gather(*[event.wait() for event in in_progress.values()]) + for domain in in_progress: + # When we have waited and it's _UNDEF, it doesn't exist + # We don't cache that it doesn't exist, or else people can't fix it + # and then restart, because their config will never be valid. + if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF: + results[domain] = IntegrationNotFound(domain) + else: + results[domain] = cast(Integration, int_or_evt) - # When we have waited and it's _UNDEF, it doesn't exist - # We don't cache that it doesn't exist, or else people can't fix it - # and then restart, because their config will never be valid. - if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF: - raise IntegrationNotFound(domain) + # First we look for custom components + if needed: + # Instead of using resolve_from_root we use the cache of custom + # components to find the integration. + custom = await async_get_custom_components(hass) + for domain, event in needed.items(): + if integration := custom.get(domain): + results[domain] = cache[domain] = integration + event.set() - if int_or_evt is not _UNDEF: - return cast(Integration, int_or_evt) + for domain in results: + if domain in needed: + del needed[domain] - event = cache[domain] = asyncio.Event() + # Now the rest use resolve_from_root + if needed: + from . import components # pylint: disable=import-outside-toplevel - try: - integration = await _async_get_integration(hass, domain) - except Exception: - # Remove event from cache. - cache.pop(domain) - event.set() - raise + integrations = await hass.async_add_executor_job( + _resolve_integrations_from_root, hass, components, list(needed) + ) + for domain, event in needed.items(): + int_or_exc = integrations.get(domain) + if not int_or_exc: + cache.pop(domain) + results[domain] = IntegrationNotFound(domain) + elif isinstance(int_or_exc, Exception): + cache.pop(domain) + exc = IntegrationNotFound(domain) + exc.__cause__ = int_or_exc + results[domain] = exc + else: + results[domain] = cache[domain] = int_or_exc + event.set() - cache[domain] = integration - event.set() - return integration - - -async def _async_get_integration(hass: HomeAssistant, domain: str) -> Integration: - if "." in domain: - raise ValueError(f"Invalid domain {domain}") - - # Instead of using resolve_from_root we use the cache of custom - # components to find the integration. - if integration := (await async_get_custom_components(hass)).get(domain): - return integration - - from . import components # pylint: disable=import-outside-toplevel - - if integration := await hass.async_add_executor_job( - Integration.resolve_from_root, hass, components, domain - ): - return integration - - raise IntegrationNotFound(domain) + return results class LoaderError(Exception): diff --git a/tests/components/analytics/test_analytics.py b/tests/components/analytics/test_analytics.py index a18c59f171f..82a61126432 100644 --- a/tests/components/analytics/test_analytics.py +++ b/tests/components/analytics/test_analytics.py @@ -269,8 +269,8 @@ async def test_send_statistics_one_integration_fails(hass, caplog, aioclient_moc hass.config.components = ["default_config"] with patch( - "homeassistant.components.analytics.analytics.async_get_integration", - side_effect=IntegrationNotFound("any"), + "homeassistant.components.analytics.analytics.async_get_integrations", + return_value={"any": IntegrationNotFound("any")}, ), patch("homeassistant.components.analytics.analytics.HA_VERSION", MOCK_VERSION): await analytics.send_analytics() @@ -291,8 +291,8 @@ async def test_send_statistics_async_get_integration_unknown_exception( hass.config.components = ["default_config"] with pytest.raises(ValueError), patch( - "homeassistant.components.analytics.analytics.async_get_integration", - side_effect=ValueError, + "homeassistant.components.analytics.analytics.async_get_integrations", + return_value={"any": ValueError()}, ), patch("homeassistant.components.analytics.analytics.HA_VERSION", MOCK_VERSION): await analytics.send_analytics() diff --git a/tests/helpers/test_translation.py b/tests/helpers/test_translation.py index 2e30a649a7b..d993233ac5d 100644 --- a/tests/helpers/test_translation.py +++ b/tests/helpers/test_translation.py @@ -135,8 +135,8 @@ async def test_get_translations_loads_config_flows(hass, mock_config_flows): "homeassistant.helpers.translation.load_translations_files", return_value={"component1": {"title": "world"}}, ), patch( - "homeassistant.helpers.translation.async_get_integration", - return_value=integration, + "homeassistant.helpers.translation.async_get_integrations", + return_value={"component1": integration}, ): translations = await translation.async_get_translations( hass, "en", "title", config_flow=True @@ -164,8 +164,8 @@ async def test_get_translations_loads_config_flows(hass, mock_config_flows): "homeassistant.helpers.translation.load_translations_files", return_value={"component2": {"title": "world"}}, ), patch( - "homeassistant.helpers.translation.async_get_integration", - return_value=integration, + "homeassistant.helpers.translation.async_get_integrations", + return_value={"component2": integration}, ): translations = await translation.async_get_translations( hass, "en", "title", config_flow=True @@ -212,8 +212,8 @@ async def test_get_translations_while_loading_components(hass): "homeassistant.helpers.translation.load_translations_files", mock_load_translation_files, ), patch( - "homeassistant.helpers.translation.async_get_integration", - return_value=integration, + "homeassistant.helpers.translation.async_get_integrations", + return_value={"component1": integration}, ): tasks = [ translation.async_get_translations(hass, "en", "title") for _ in range(5)