Consolidate executor jobs when loading integration manifests (#75176)

This commit is contained in:
J. Nick Koston 2022-07-14 22:06:08 +02:00 committed by GitHub
parent fef1b842ce
commit 61cc9f5288
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 143 additions and 118 deletions

View File

@ -35,7 +35,6 @@ from .setup import (
async_setup_component, async_setup_component,
) )
from .util import dt as dt_util from .util import dt as dt_util
from .util.async_ import gather_with_concurrency
from .util.logging import async_activate_log_queue_handler from .util.logging import async_activate_log_queue_handler
from .util.package import async_get_user_site, is_virtual_env from .util.package import async_get_user_site, is_virtual_env
@ -479,14 +478,9 @@ async def _async_set_up_integrations(
integrations_to_process = [ integrations_to_process = [
int_or_exc int_or_exc
for int_or_exc in await gather_with_concurrency( for int_or_exc in (
loader.MAX_LOAD_CONCURRENTLY, await loader.async_get_integrations(hass, old_to_resolve)
*( ).values()
loader.async_get_integration(hass, domain)
for domain in old_to_resolve
),
return_exceptions=True,
)
if isinstance(int_or_exc, loader.Integration) if isinstance(int_or_exc, loader.Integration)
] ]
resolve_dependencies_tasks = [ resolve_dependencies_tasks = [

View File

@ -18,7 +18,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.helpers.system_info import async_get_system_info from homeassistant.helpers.system_info import async_get_system_info
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integrations
from homeassistant.setup import async_get_loaded_integrations from homeassistant.setup import async_get_loaded_integrations
from .const import ( from .const import (
@ -182,15 +182,9 @@ class Analytics:
if self.preferences.get(ATTR_USAGE, False) or self.preferences.get( if self.preferences.get(ATTR_USAGE, False) or self.preferences.get(
ATTR_STATISTICS, False ATTR_STATISTICS, False
): ):
configured_integrations = await asyncio.gather( domains = async_get_loaded_integrations(self.hass)
*( configured_integrations = await async_get_integrations(self.hass, domains)
async_get_integration(self.hass, domain) for integration in configured_integrations.values():
for domain in async_get_loaded_integrations(self.hass)
),
return_exceptions=True,
)
for integration in configured_integrations:
if isinstance(integration, IntegrationNotFound): if isinstance(integration, IntegrationNotFound):
continue continue

View File

@ -1,7 +1,6 @@
"""Commands part of Websocket API.""" """Commands part of Websocket API."""
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable from collections.abc import Callable
import datetime as dt import datetime as dt
import json import json
@ -32,7 +31,12 @@ from homeassistant.helpers.event import (
) )
from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import (
Integration,
IntegrationNotFound,
async_get_integration,
async_get_integrations,
)
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
from homeassistant.util.json import ( from homeassistant.util.json import (
find_paths_unserializable_data, find_paths_unserializable_data,
@ -372,9 +376,13 @@ async def handle_manifest_list(
wanted_integrations = msg.get("integrations") wanted_integrations = msg.get("integrations")
if wanted_integrations is None: if wanted_integrations is None:
wanted_integrations = async_get_loaded_integrations(hass) wanted_integrations = async_get_loaded_integrations(hass)
integrations = await asyncio.gather(
*(async_get_integration(hass, domain) for domain in wanted_integrations) ints_or_excs = await async_get_integrations(hass, wanted_integrations)
) integrations: list[Integration] = []
for int_or_exc in ints_or_excs.values():
if isinstance(int_or_exc, Exception):
raise int_or_exc
integrations.append(int_or_exc)
connection.send_result( connection.send_result(
msg["id"], [integration.manifest for integration in integrations] msg["id"], [integration.manifest for integration in integrations]
) )
@ -706,12 +714,12 @@ async def handle_supported_brands(
) -> None: ) -> None:
"""Handle supported brands command.""" """Handle supported brands command."""
data = {} data = {}
for integration in await asyncio.gather(
*[
async_get_integration(hass, integration)
for integration in supported_brands.HAS_SUPPORTED_BRANDS
]
):
data[integration.domain] = integration.manifest["supported_brands"]
ints_or_excs = await async_get_integrations(
hass, supported_brands.HAS_SUPPORTED_BRANDS
)
for int_or_exc in ints_or_excs.values():
if isinstance(int_or_exc, Exception):
raise int_or_exc
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
connection.send_result(msg["id"], data) connection.send_result(msg["id"], data)

View File

@ -31,13 +31,7 @@ from homeassistant.exceptions import (
Unauthorized, Unauthorized,
UnknownUser, UnknownUser,
) )
from homeassistant.loader import ( from homeassistant.loader import Integration, async_get_integrations, bind_hass
MAX_LOAD_CONCURRENTLY,
Integration,
async_get_integration,
bind_hass,
)
from homeassistant.util.async_ import gather_with_concurrency
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
from homeassistant.util.yaml.loader import JSON_TYPE from homeassistant.util.yaml.loader import JSON_TYPE
@ -467,10 +461,12 @@ async def async_get_all_descriptions(
loaded = {} loaded = {}
if missing: if missing:
integrations = await gather_with_concurrency( ints_or_excs = await async_get_integrations(hass, missing)
MAX_LOAD_CONCURRENTLY, integrations = [
*(async_get_integration(hass, domain) for domain in missing), int_or_exc
) for int_or_exc in ints_or_excs.values()
if isinstance(int_or_exc, Integration)
]
contents = await hass.async_add_executor_job( contents = await hass.async_add_executor_job(
_load_services_files, hass, integrations _load_services_files, hass, integrations

View File

@ -9,13 +9,11 @@ from typing import Any
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import ( from homeassistant.loader import (
MAX_LOAD_CONCURRENTLY,
Integration, Integration,
async_get_config_flows, async_get_config_flows,
async_get_integration, async_get_integrations,
bind_hass, bind_hass,
) )
from homeassistant.util.async_ import gather_with_concurrency
from homeassistant.util.json import load_json from homeassistant.util.json import load_json
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -151,16 +149,13 @@ async def async_get_component_strings(
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Load translations.""" """Load translations."""
domains = list({loaded.split(".")[-1] for loaded in components}) domains = list({loaded.split(".")[-1] for loaded in components})
integrations = dict(
zip(
domains,
await gather_with_concurrency(
MAX_LOAD_CONCURRENTLY,
*(async_get_integration(hass, domain) for domain in domains),
),
)
)
integrations: dict[str, Integration] = {}
ints_or_excs = await async_get_integrations(hass, domains)
for domain, int_or_exc in ints_or_excs.items():
if isinstance(int_or_exc, Exception):
raise int_or_exc
integrations[domain] = int_or_exc
translations: dict[str, Any] = {} translations: dict[str, Any] = {}
# Determine paths of missing components/platforms # Determine paths of missing components/platforms

View File

@ -7,7 +7,7 @@ documentation as possible to keep it understandable.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable, Iterable
from contextlib import suppress from contextlib import suppress
import functools as ft import functools as ft
import importlib import importlib
@ -31,7 +31,6 @@ from .generated.ssdp import SSDP
from .generated.usb import USB from .generated.usb import USB
from .generated.zeroconf import HOMEKIT, ZEROCONF from .generated.zeroconf import HOMEKIT, ZEROCONF
from .helpers.json import JSON_DECODE_EXCEPTIONS, json_loads from .helpers.json import JSON_DECODE_EXCEPTIONS, json_loads
from .util.async_ import gather_with_concurrency
# Typing imports that create a circular dependency # Typing imports that create a circular dependency
if TYPE_CHECKING: if TYPE_CHECKING:
@ -128,6 +127,7 @@ class Manifest(TypedDict, total=False):
version: str version: str
codeowners: list[str] codeowners: list[str]
loggers: list[str] loggers: list[str]
supported_brands: dict[str, str]
def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest: def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest:
@ -166,19 +166,15 @@ async def _async_get_custom_components(
get_sub_directories, custom_components.__path__ get_sub_directories, custom_components.__path__
) )
integrations = await gather_with_concurrency( integrations = await hass.async_add_executor_job(
MAX_LOAD_CONCURRENTLY, _resolve_integrations_from_root,
*( hass,
hass.async_add_executor_job( custom_components,
Integration.resolve_from_root, hass, custom_components, comp.name [comp.name for comp in dirs],
) )
for comp in dirs
),
)
return { return {
integration.domain: integration integration.domain: integration
for integration in integrations for integration in integrations.values()
if integration is not None if integration is not None
} }
@ -681,59 +677,101 @@ class Integration:
return f"<Integration {self.domain}: {self.pkg_path}>" return f"<Integration {self.domain}: {self.pkg_path}>"
def _resolve_integrations_from_root(
hass: HomeAssistant, root_module: ModuleType, domains: list[str]
) -> dict[str, Integration]:
"""Resolve multiple integrations from root."""
integrations: dict[str, Integration] = {}
for domain in domains:
try:
integration = Integration.resolve_from_root(hass, root_module, domain)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error loading integration: %s", domain)
else:
if integration:
integrations[domain] = integration
return integrations
async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration: async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration:
"""Get an integration.""" """Get integration."""
integrations_or_excs = await async_get_integrations(hass, [domain])
int_or_exc = integrations_or_excs[domain]
if isinstance(int_or_exc, Integration):
return int_or_exc
raise int_or_exc
async def async_get_integrations(
hass: HomeAssistant, domains: Iterable[str]
) -> dict[str, Integration | Exception]:
"""Get integrations."""
if (cache := hass.data.get(DATA_INTEGRATIONS)) is None: if (cache := hass.data.get(DATA_INTEGRATIONS)) is None:
if not _async_mount_config_dir(hass): if not _async_mount_config_dir(hass):
raise IntegrationNotFound(domain) return {domain: IntegrationNotFound(domain) for domain in domains}
cache = hass.data[DATA_INTEGRATIONS] = {} cache = hass.data[DATA_INTEGRATIONS] = {}
results: dict[str, Integration | Exception] = {}
needed: dict[str, asyncio.Event] = {}
in_progress: dict[str, asyncio.Event] = {}
for domain in domains:
int_or_evt: Integration | asyncio.Event | None = cache.get(domain, _UNDEF) int_or_evt: Integration | asyncio.Event | None = cache.get(domain, _UNDEF)
if isinstance(int_or_evt, asyncio.Event): if isinstance(int_or_evt, asyncio.Event):
await int_or_evt.wait() in_progress[domain] = int_or_evt
elif int_or_evt is not _UNDEF:
results[domain] = cast(Integration, int_or_evt)
elif "." in domain:
results[domain] = ValueError(f"Invalid domain {domain}")
else:
needed[domain] = cache[domain] = asyncio.Event()
if in_progress:
await asyncio.gather(*[event.wait() for event in in_progress.values()])
for domain in in_progress:
# When we have waited and it's _UNDEF, it doesn't exist # When we have waited and it's _UNDEF, it doesn't exist
# We don't cache that it doesn't exist, or else people can't fix it # We don't cache that it doesn't exist, or else people can't fix it
# and then restart, because their config will never be valid. # and then restart, because their config will never be valid.
if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF: if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF:
raise IntegrationNotFound(domain) results[domain] = IntegrationNotFound(domain)
else:
if int_or_evt is not _UNDEF: results[domain] = cast(Integration, int_or_evt)
return cast(Integration, int_or_evt)
event = cache[domain] = asyncio.Event()
try:
integration = await _async_get_integration(hass, domain)
except Exception:
# Remove event from cache.
cache.pop(domain)
event.set()
raise
cache[domain] = integration
event.set()
return integration
async def _async_get_integration(hass: HomeAssistant, domain: str) -> Integration:
if "." in domain:
raise ValueError(f"Invalid domain {domain}")
# First we look for custom components
if needed:
# Instead of using resolve_from_root we use the cache of custom # Instead of using resolve_from_root we use the cache of custom
# components to find the integration. # components to find the integration.
if integration := (await async_get_custom_components(hass)).get(domain): custom = await async_get_custom_components(hass)
return integration for domain, event in needed.items():
if integration := custom.get(domain):
results[domain] = cache[domain] = integration
event.set()
for domain in results:
if domain in needed:
del needed[domain]
# Now the rest use resolve_from_root
if needed:
from . import components # pylint: disable=import-outside-toplevel from . import components # pylint: disable=import-outside-toplevel
if integration := await hass.async_add_executor_job( integrations = await hass.async_add_executor_job(
Integration.resolve_from_root, hass, components, domain _resolve_integrations_from_root, hass, components, list(needed)
): )
return integration for domain, event in needed.items():
int_or_exc = integrations.get(domain)
if not int_or_exc:
cache.pop(domain)
results[domain] = IntegrationNotFound(domain)
elif isinstance(int_or_exc, Exception):
cache.pop(domain)
exc = IntegrationNotFound(domain)
exc.__cause__ = int_or_exc
results[domain] = exc
else:
results[domain] = cache[domain] = int_or_exc
event.set()
raise IntegrationNotFound(domain) return results
class LoaderError(Exception): class LoaderError(Exception):

View File

@ -269,8 +269,8 @@ async def test_send_statistics_one_integration_fails(hass, caplog, aioclient_moc
hass.config.components = ["default_config"] hass.config.components = ["default_config"]
with patch( with patch(
"homeassistant.components.analytics.analytics.async_get_integration", "homeassistant.components.analytics.analytics.async_get_integrations",
side_effect=IntegrationNotFound("any"), return_value={"any": IntegrationNotFound("any")},
), patch("homeassistant.components.analytics.analytics.HA_VERSION", MOCK_VERSION): ), patch("homeassistant.components.analytics.analytics.HA_VERSION", MOCK_VERSION):
await analytics.send_analytics() await analytics.send_analytics()
@ -291,8 +291,8 @@ async def test_send_statistics_async_get_integration_unknown_exception(
hass.config.components = ["default_config"] hass.config.components = ["default_config"]
with pytest.raises(ValueError), patch( with pytest.raises(ValueError), patch(
"homeassistant.components.analytics.analytics.async_get_integration", "homeassistant.components.analytics.analytics.async_get_integrations",
side_effect=ValueError, return_value={"any": ValueError()},
), patch("homeassistant.components.analytics.analytics.HA_VERSION", MOCK_VERSION): ), patch("homeassistant.components.analytics.analytics.HA_VERSION", MOCK_VERSION):
await analytics.send_analytics() await analytics.send_analytics()

View File

@ -135,8 +135,8 @@ async def test_get_translations_loads_config_flows(hass, mock_config_flows):
"homeassistant.helpers.translation.load_translations_files", "homeassistant.helpers.translation.load_translations_files",
return_value={"component1": {"title": "world"}}, return_value={"component1": {"title": "world"}},
), patch( ), patch(
"homeassistant.helpers.translation.async_get_integration", "homeassistant.helpers.translation.async_get_integrations",
return_value=integration, return_value={"component1": integration},
): ):
translations = await translation.async_get_translations( translations = await translation.async_get_translations(
hass, "en", "title", config_flow=True hass, "en", "title", config_flow=True
@ -164,8 +164,8 @@ async def test_get_translations_loads_config_flows(hass, mock_config_flows):
"homeassistant.helpers.translation.load_translations_files", "homeassistant.helpers.translation.load_translations_files",
return_value={"component2": {"title": "world"}}, return_value={"component2": {"title": "world"}},
), patch( ), patch(
"homeassistant.helpers.translation.async_get_integration", "homeassistant.helpers.translation.async_get_integrations",
return_value=integration, return_value={"component2": integration},
): ):
translations = await translation.async_get_translations( translations = await translation.async_get_translations(
hass, "en", "title", config_flow=True hass, "en", "title", config_flow=True
@ -212,8 +212,8 @@ async def test_get_translations_while_loading_components(hass):
"homeassistant.helpers.translation.load_translations_files", "homeassistant.helpers.translation.load_translations_files",
mock_load_translation_files, mock_load_translation_files,
), patch( ), patch(
"homeassistant.helpers.translation.async_get_integration", "homeassistant.helpers.translation.async_get_integrations",
return_value=integration, return_value={"component1": integration},
): ):
tasks = [ tasks = [
translation.async_get_translations(hass, "en", "title") for _ in range(5) translation.async_get_translations(hass, "en", "title") for _ in range(5)