Add support for preloading platforms in the loader (#112282)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
J. Nick Koston 2024-03-04 16:33:12 -10:00 committed by GitHub
parent d0c81f7d00
commit 1e173e82d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 163 additions and 97 deletions

View File

@ -1459,9 +1459,9 @@ async def async_process_component_config( # noqa: C901
# Check if the integration has a custom config validator # Check if the integration has a custom config validator
config_validator = None config_validator = None
# A successful call to async_get_component will prime # A successful call to async_get_component will prime
# the cache for platform_exists to ensure it does no # the cache for platforms_exists to ensure it does no
# blocking I/O # blocking I/O
if integration.platform_exists("config") is not False: if integration.platforms_exists(("config",)):
# If the config platform cannot possibly exist, don't try to load it. # If the config platform cannot possibly exist, don't try to load it.
try: try:
config_validator = await integration.async_get_platform("config") config_validator = await integration.async_get_platform("config")

View File

@ -15,6 +15,7 @@ from homeassistant.loader import (
Integration, Integration,
async_get_integrations, async_get_integrations,
async_get_loaded_integration, async_get_loaded_integration,
async_register_preload_platform,
bind_hass, bind_hass,
) )
from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded
@ -58,7 +59,7 @@ def _get_platform(
# `stat()` system calls which is far cheaper than calling # `stat()` system calls which is far cheaper than calling
# `integration.get_platform` # `integration.get_platform`
# #
if integration.platform_exists(platform_name) is False: if not integration.platforms_exists((platform_name,)):
# If the platform cannot possibly exist, don't bother trying to load it # If the platform cannot possibly exist, don't bother trying to load it
return None return None
@ -127,6 +128,7 @@ async def async_process_integration_platforms(
else: else:
integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS] integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS]
async_register_preload_platform(hass, platform_name)
top_level_components = {comp for comp in hass.config.components if "." not in comp} top_level_components = {comp for comp in hass.config.components if "." not in comp}
process_job = HassJob( process_job = HassJob(
catch_log_exception( catch_log_exception(

View File

@ -54,10 +54,38 @@ _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
#
# Integration.get_component can check the preload platforms and try to
# import their code up front, to avoid a thundering herd of import
# executor jobs later in the startup process.
#
# Default platforms are prepopulated in this list so that, by the time
# a component is loaded, we check whether each of these platforms is
# available.
#
# async_setup stores a copy of this list in hass.data; that copy can be
# extended by calling async_register_preload_platform.
#
BASE_PRELOAD_PLATFORMS = [
"config",
"diagnostics",
"energy",
"group",
"logbook",
"hardware",
"intent",
"media_source",
"recorder",
"repairs",
"system_health",
"trigger",
]
DATA_COMPONENTS = "components" DATA_COMPONENTS = "components"
DATA_INTEGRATIONS = "integrations" DATA_INTEGRATIONS = "integrations"
DATA_MISSING_PLATFORMS = "missing_platforms" DATA_MISSING_PLATFORMS = "missing_platforms"
DATA_CUSTOM_COMPONENTS = "custom_components" DATA_CUSTOM_COMPONENTS = "custom_components"
DATA_PRELOAD_PLATFORMS = "preload_platforms"
PACKAGE_CUSTOM_COMPONENTS = "custom_components" PACKAGE_CUSTOM_COMPONENTS = "custom_components"
PACKAGE_BUILTIN = "homeassistant.components" PACKAGE_BUILTIN = "homeassistant.components"
CUSTOM_WARNING = ( CUSTOM_WARNING = (
@ -161,7 +189,7 @@ class Manifest(TypedDict, total=False):
disabled: str disabled: str
domain: str domain: str
integration_type: Literal[ integration_type: Literal[
"entity", "device", "hardware", "helper", "hub", "service", "system" "entity", "device", "hardware", "helper", "hub", "service", "system", "virtual"
] ]
dependencies: list[str] dependencies: list[str]
after_dependencies: list[str] after_dependencies: list[str]
@ -192,6 +220,7 @@ def async_setup(hass: HomeAssistant) -> None:
hass.data[DATA_COMPONENTS] = {} hass.data[DATA_COMPONENTS] = {}
hass.data[DATA_INTEGRATIONS] = {} hass.data[DATA_INTEGRATIONS] = {}
hass.data[DATA_MISSING_PLATFORMS] = {} hass.data[DATA_MISSING_PLATFORMS] = {}
hass.data[DATA_PRELOAD_PLATFORMS] = BASE_PRELOAD_PLATFORMS.copy()
def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest: def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest:
@ -568,6 +597,14 @@ async def async_get_mqtt(hass: HomeAssistant) -> dict[str, list[str]]:
return mqtt return mqtt
@callback
def async_register_preload_platform(hass: HomeAssistant, platform_name: str) -> None:
    """Mark a platform for preloading.

    Appends ``platform_name`` to the preload list kept in
    ``hass.data[DATA_PRELOAD_PLATFORMS]``. A name that is already
    listed is left alone, so registering the same platform more than
    once has no further effect.
    """
    registered: list[str] = hass.data[DATA_PRELOAD_PLATFORMS]
    if platform_name in registered:
        return
    registered.append(platform_name)
class Integration: class Integration:
"""An integration in Home Assistant.""" """An integration in Home Assistant."""
@ -590,11 +627,16 @@ class Integration:
) )
continue continue
file_path = manifest_path.parent
# Avoid the listdir for virtual integrations
# as they cannot have any platforms
is_virtual = manifest.get("integration_type") == "virtual"
integration = cls( integration = cls(
hass, hass,
f"{root_module.__name__}.{domain}", f"{root_module.__name__}.{domain}",
manifest_path.parent, file_path,
manifest, manifest,
None if is_virtual else set(os.listdir(file_path)),
) )
if integration.is_built_in: if integration.is_built_in:
@ -647,6 +689,7 @@ class Integration:
pkg_path: str, pkg_path: str,
file_path: pathlib.Path, file_path: pathlib.Path,
manifest: Manifest, manifest: Manifest,
top_level_files: set[str] | None = None,
) -> None: ) -> None:
"""Initialize an integration.""" """Initialize an integration."""
self.hass = hass self.hass = hass
@ -662,6 +705,8 @@ class Integration:
self._all_dependencies_resolved = True self._all_dependencies_resolved = True
self._all_dependencies = set() self._all_dependencies = set()
platforms_to_preload: list[str] = hass.data[DATA_PRELOAD_PLATFORMS]
self._platforms_to_preload = platforms_to_preload
self._component_future: asyncio.Future[ComponentProtocol] | None = None self._component_future: asyncio.Future[ComponentProtocol] | None = None
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {} self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS] cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS]
@ -670,6 +715,7 @@ class Integration:
DATA_MISSING_PLATFORMS DATA_MISSING_PLATFORMS
] ]
self._missing_platforms_cache = missing_platforms_cache self._missing_platforms_cache = missing_platforms_cache
self._top_level_files = top_level_files or set()
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path) _LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
@cached_property @cached_property
@ -735,7 +781,9 @@ class Integration:
@cached_property @cached_property
def integration_type( def integration_type(
self, self,
) -> Literal["entity", "device", "hardware", "helper", "hub", "service", "system"]: ) -> Literal[
"entity", "device", "hardware", "helper", "hub", "service", "system", "virtual"
]:
"""Return the integration type.""" """Return the integration type."""
return self.manifest.get("integration_type", "hub") return self.manifest.get("integration_type", "hub")
@ -878,7 +926,9 @@ class Integration:
self._component_future = self.hass.loop.create_future() self._component_future = self.hass.loop.create_future()
try: try:
try: try:
comp = await self.hass.async_add_import_executor_job(self.get_component) comp = await self.hass.async_add_import_executor_job(
self.get_component, True
)
except ImportError as ex: except ImportError as ex:
load_executor = False load_executor = False
_LOGGER.debug( _LOGGER.debug(
@ -909,7 +959,7 @@ class Integration:
return comp return comp
def get_component(self) -> ComponentProtocol: def get_component(self, preload_platforms: bool = False) -> ComponentProtocol:
"""Return the component. """Return the component.
This method must be thread-safe as it's called from the executor This method must be thread-safe as it's called from the executor
@ -940,22 +990,19 @@ class Integration:
) )
raise ImportError(f"Exception importing {self.pkg_path}") from err raise ImportError(f"Exception importing {self.pkg_path}") from err
if self.platform_exists("config"): if preload_platforms:
# Setting up a component always checks if the config for platform_name in self.platforms_exists(self._platforms_to_preload):
# platform exists. Since we may be running in the executor with suppress(ImportError):
# we will use this opportunity to cache the config platform self.get_platform(platform_name)
# as well.
with suppress(ImportError):
self.get_platform("config")
if self.config_flow: if self.config_flow:
# If there is a config flow, we will cache it as well since # If there is a config flow, we will cache it as well since
# config entry setup always has to load the flow to get the # config entry setup always has to load the flow to get the
# major/minor version for migrations. Since we may be running # major/minor version for migrations. Since we may be running
# in the executor we will use this opportunity to cache the # in the executor we will use this opportunity to cache the
# config_flow as well. # config_flow as well.
with suppress(ImportError): with suppress(ImportError):
self.get_platform("config_flow") self.get_platform("config_flow")
return cache[domain] return cache[domain]
@ -985,7 +1032,7 @@ class Integration:
for platform_name in platform_names: for platform_name in platform_names:
full_name = f"{domain}.{platform_name}" full_name = f"{domain}.{platform_name}"
if platform := self._get_platform_cached(full_name): if platform := self._get_platform_cached_or_raise(full_name):
platforms[platform_name] = platform platforms[platform_name] = platform
continue continue
@ -1065,7 +1112,7 @@ class Integration:
return platforms return platforms
def _get_platform_cached(self, full_name: str) -> ModuleType | None: def _get_platform_cached_or_raise(self, full_name: str) -> ModuleType | None:
"""Return a platform for an integration from cache.""" """Return a platform for an integration from cache."""
if full_name in self._cache: if full_name in self._cache:
# the cache is either a ModuleType or a ComponentProtocol # the cache is either a ModuleType or a ComponentProtocol
@ -1077,43 +1124,35 @@ class Integration:
def get_platform(self, platform_name: str) -> ModuleType: def get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration.""" """Return a platform for an integration."""
if platform := self._get_platform_cached(f"{self.domain}.{platform_name}"): if platform := self._get_platform_cached_or_raise(
f"{self.domain}.{platform_name}"
):
return platform return platform
return self._load_platform(platform_name) return self._load_platform(platform_name)
def platform_exists(self, platform_name: str) -> bool | None: def platforms_exists(self, platform_names: Iterable[str]) -> list[str]:
"""Check if a platform exists for an integration. """Check if a platforms exists for an integration.
Returns True if the platform exists, False if it does not. This method is thread-safe and can be called from the executor
or event loop without doing blocking I/O.
If it cannot be determined if the platform exists without attempting
to import the component, it returns None. This will only happen
if this function is called before get_component or async_get_component
has been called for the integration or the integration failed to load.
""" """
full_name = f"{self.domain}.{platform_name}" files = self._top_level_files
cache = self._cache domain = self.domain
if full_name in cache: existing_platforms: list[str] = []
return True missing_platforms = self._missing_platforms_cache
for platform_name in platform_names:
full_name = f"{domain}.{platform_name}"
if full_name not in missing_platforms and (
f"{platform_name}.py" in files or platform_name in files
):
existing_platforms.append(platform_name)
continue
missing_platforms[full_name] = ModuleNotFoundError(
f"Platform {full_name} not found",
name=f"{self.pkg_path}.{platform_name}",
)
if full_name in self._missing_platforms_cache: return existing_platforms
return False
if not (component := cache.get(self.domain)) or not (
file := getattr(component, "__file__", None)
):
return None
path: pathlib.Path = pathlib.Path(file).parent.joinpath(platform_name)
if os.path.exists(path.with_suffix(".py")) or os.path.exists(path):
return True
exc = ModuleNotFoundError(
f"Platform {full_name} not found",
name=f"{self.pkg_path}.{platform_name}",
)
self._missing_platforms_cache[full_name] = exc
return False
def _load_platform(self, platform_name: str) -> ModuleType: def _load_platform(self, platform_name: str) -> ModuleType:
"""Load a platform for an integration. """Load a platform for an integration.

View File

@ -1396,6 +1396,7 @@ def mock_integration(
else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}", else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}",
pathlib.Path(""), pathlib.Path(""),
module.mock_manifest(), module.mock_manifest(),
set(),
) )
def mock_import_platform(platform_name: str) -> NoReturn: def mock_import_platform(platform_name: str) -> NoReturn:
@ -1423,13 +1424,14 @@ def mock_platform(
platform_path is in form hue.config_flow. platform_path is in form hue.config_flow.
""" """
domain = platform_path.split(".")[0] domain, _, platform_name = platform_path.partition(".")
integration_cache = hass.data[loader.DATA_INTEGRATIONS] integration_cache = hass.data[loader.DATA_INTEGRATIONS]
module_cache = hass.data[loader.DATA_COMPONENTS] module_cache = hass.data[loader.DATA_COMPONENTS]
if domain not in integration_cache: if domain not in integration_cache:
mock_integration(hass, MockModule(domain)) mock_integration(hass, MockModule(domain))
integration_cache[domain]._top_level_files.add(f"{platform_name}.py")
_LOGGER.info("Adding mock integration platform: %s", platform_path) _LOGGER.info("Adding mock integration platform: %s", platform_path)
module_cache[platform_path] = module or Mock() module_cache[platform_path] = module or Mock()

View File

@ -1094,27 +1094,36 @@ async def test_async_get_component_preloads_config_and_config_flow(
assert "homeassistant.components.executor_import" not in sys.modules assert "homeassistant.components.executor_import" not in sys.modules
assert "custom_components.executor_import" not in sys.modules assert "custom_components.executor_import" not in sys.modules
platform_exists_calls = []
def mock_platforms_exists(platforms: list[str]) -> bool:
platform_exists_calls.append(platforms)
return platforms
with patch( with patch(
"homeassistant.loader.importlib.import_module" "homeassistant.loader.importlib.import_module"
) as mock_import, patch.object( ) as mock_import, patch.object(
executor_import_integration, "platform_exists", return_value=True executor_import_integration, "platforms_exists", mock_platforms_exists
) as mock_platform_exists: ):
await executor_import_integration.async_get_component() await executor_import_integration.async_get_component()
assert mock_platform_exists.call_count == 1 assert len(platform_exists_calls[0]) == len(loader.BASE_PRELOAD_PLATFORMS)
assert mock_import.call_count == 3 assert mock_import.call_count == 2 + len(loader.BASE_PRELOAD_PLATFORMS)
assert ( assert (
mock_import.call_args_list[0][0][0] mock_import.call_args_list[0][0][0]
== "homeassistant.components.executor_import" == "homeassistant.components.executor_import"
) )
assert ( checked_platforms = {
mock_import.call_args_list[1][0][0] mock_import.call_args_list[i][0][0]
== "homeassistant.components.executor_import.config" for i in range(1, len(mock_import.call_args_list))
) }
assert ( assert checked_platforms == {
mock_import.call_args_list[2][0][0] "homeassistant.components.executor_import.config_flow",
== "homeassistant.components.executor_import.config_flow" *(
) f"homeassistant.components.executor_import.{platform}"
for platform in loader.BASE_PRELOAD_PLATFORMS
),
}
async def test_async_get_component_loads_loop_if_already_in_sys_modules( async def test_async_get_component_loads_loop_if_already_in_sys_modules(
@ -1134,8 +1143,8 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules(
assert "test_package_loaded_executor.config_flow" not in hass.config.components assert "test_package_loaded_executor.config_flow" not in hass.config.components
config_flow_module_name = f"{integration.pkg_path}.config_flow" config_flow_module_name = f"{integration.pkg_path}.config_flow"
module_mock = MagicMock() module_mock = MagicMock(__file__="__init__.py")
config_flow_module_mock = MagicMock() config_flow_module_mock = MagicMock(__file__="config_flow.py")
def import_module(name: str) -> Any: def import_module(name: str) -> Any:
if name == integration.pkg_path: if name == integration.pkg_path:
@ -1194,8 +1203,8 @@ async def test_async_get_component_concurrent_loads(
assert "test_package_loaded_executor.config_flow" not in hass.config.components assert "test_package_loaded_executor.config_flow" not in hass.config.components
config_flow_module_name = f"{integration.pkg_path}.config_flow" config_flow_module_name = f"{integration.pkg_path}.config_flow"
module_mock = MagicMock() module_mock = MagicMock(__file__="__init__.py")
config_flow_module_mock = MagicMock() config_flow_module_mock = MagicMock(__file__="config_flow.py")
imports = [] imports = []
start_event = threading.Event() start_event = threading.Event()
import_event = asyncio.Event() import_event = asyncio.Event()
@ -1232,7 +1241,8 @@ async def test_async_get_component_concurrent_loads(
assert comp1 is module_mock assert comp1 is module_mock
assert comp2 is module_mock assert comp2 is module_mock
assert imports == [integration.pkg_path, config_flow_module_name] assert integration.pkg_path in imports
assert config_flow_module_name in imports
async def test_async_get_component_deadlock_fallback( async def test_async_get_component_deadlock_fallback(
@ -1243,7 +1253,7 @@ async def test_async_get_component_deadlock_fallback(
hass, "executor_import", True, import_executor=True hass, "executor_import", True, import_executor=True
) )
assert executor_import_integration.import_executor is True assert executor_import_integration.import_executor is True
module_mock = MagicMock() module_mock = MagicMock(__file__="__init__.py")
import_attempts = 0 import_attempts = 0
def mock_import(module: str, *args: Any, **kwargs: Any) -> Any: def mock_import(module: str, *args: Any, **kwargs: Any) -> Any:
@ -1395,38 +1405,51 @@ async def test_async_get_platform_raises_after_import_failure(
assert "loaded_executor=False" not in caplog.text assert "loaded_executor=False" not in caplog.text
async def test_platform_exists( async def test_platforms_exists(
hass: HomeAssistant, enable_custom_integrations: None hass: HomeAssistant, enable_custom_integrations: None
) -> None: ) -> None:
"""Test platform_exists.""" """Test platforms_exists."""
integration = await loader.async_get_integration(hass, "test_integration_platform") original_os_listdir = os.listdir
assert integration.domain == "test_integration_platform"
# get_component never called, will return None paths: list[str] = []
assert integration.platform_exists("non_existing") is None
component = integration.get_component() def mock_list_dir(path: str) -> list[str]:
paths.append(path)
return original_os_listdir(path)
with patch("homeassistant.loader.os.listdir", mock_list_dir):
integration = await loader.async_get_integration(
hass, "test_integration_platform"
)
assert integration.domain == "test_integration_platform"
# Verify the files cache is primed
assert integration.file_path in paths
# Loading the component should not call os.listdir again
with patch("homeassistant.loader.os.listdir", wraps=os.listdir) as mock_exists:
component = integration.get_component()
assert component.DOMAIN == "test_integration_platform" assert component.DOMAIN == "test_integration_platform"
# component is loaded, should now return False # The files cache should be primed when
with patch( # the integration is resolved
"homeassistant.loader.os.path.exists", wraps=os.path.exists assert mock_exists.call_count == 0
) as mock_exists:
assert integration.platform_exists("non_existing") is False
# We should check if the file exists
assert mock_exists.call_count == 2
# component is loaded, should now return False # component is loaded, should now return False
with patch( with patch("homeassistant.loader.os.listdir", wraps=os.listdir) as mock_exists:
"homeassistant.loader.os.path.exists", wraps=os.path.exists assert integration.platforms_exists(("non_existing",)) == []
) as mock_exists:
assert integration.platform_exists("non_existing") is False # We should remember which files exist
assert mock_exists.call_count == 0
# component is loaded, should now return False
with patch("homeassistant.loader.os.listdir", wraps=os.listdir) as mock_exists:
assert integration.platforms_exists(("non_existing",)) == []
# We should remember the file does not exist # We should remember the file does not exist
assert mock_exists.call_count == 0 assert mock_exists.call_count == 0
assert integration.platform_exists("group") is True assert integration.platforms_exists(["group"]) == ["group"]
platform = await integration.async_get_platform("group") platform = await integration.async_get_platform("group")
assert platform.MAGIC == 1 assert platform.MAGIC == 1
@ -1434,7 +1457,7 @@ async def test_platform_exists(
platform = integration.get_platform("group") platform = integration.get_platform("group")
assert platform.MAGIC == 1 assert platform.MAGIC == 1
assert integration.platform_exists("group") is True assert integration.platforms_exists(["group"]) == ["group"]
async def test_async_get_platforms_loads_loop_if_already_in_sys_modules( async def test_async_get_platforms_loads_loop_if_already_in_sys_modules(

View File

@ -843,7 +843,7 @@ async def test_async_prepare_setup_platform(
caplog.clear() caplog.clear()
# There is no actual config platform for this integration # There is no actual config platform for this integration
assert await setup.async_prepare_setup_platform(hass, {}, "config", "test") is None assert await setup.async_prepare_setup_platform(hass, {}, "config", "test") is None
assert "test.config not found" in caplog.text assert "No module named 'custom_components.test.config'" in caplog.text
button_platform = ( button_platform = (
await setup.async_prepare_setup_platform(hass, {}, "button", "test") is None await setup.async_prepare_setup_platform(hass, {}, "button", "test") is None