Add support for preloading platforms in the loader (#112282)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
J. Nick Koston 2024-03-04 16:33:12 -10:00 committed by GitHub
parent d0c81f7d00
commit 1e173e82d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 163 additions and 97 deletions

View File

@ -1459,9 +1459,9 @@ async def async_process_component_config( # noqa: C901
# Check if the integration has a custom config validator
config_validator = None
# A successful call to async_get_component will prime
# the cache for platform_exists to ensure it does no
# the cache for platforms_exists to ensure it does no
# blocking I/O
if integration.platform_exists("config") is not False:
if integration.platforms_exists(("config",)):
# If the config platform cannot possibly exist, don't try to load it.
try:
config_validator = await integration.async_get_platform("config")

View File

@ -15,6 +15,7 @@ from homeassistant.loader import (
Integration,
async_get_integrations,
async_get_loaded_integration,
async_register_preload_platform,
bind_hass,
)
from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded
@ -58,7 +59,7 @@ def _get_platform(
# `stat()` system calls which is far cheaper than calling
# `integration.get_platform`
#
if integration.platform_exists(platform_name) is False:
if not integration.platforms_exists((platform_name,)):
# If the platform cannot possibly exist, don't bother trying to load it
return None
@ -127,6 +128,7 @@ async def async_process_integration_platforms(
else:
integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS]
async_register_preload_platform(hass, platform_name)
top_level_components = {comp for comp in hass.config.components if "." not in comp}
process_job = HassJob(
catch_log_exception(

View File

@ -54,10 +54,38 @@ _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
_LOGGER = logging.getLogger(__name__)
#
# Integration.get_component will check preload platforms and
# try to import the code to avoid a thundering heard of import
# executor jobs later in the startup process.
#
# default platforms are prepopulated in this list to ensure that
# by the time the component is loaded, we check if the platform is
# available.
#
# This list can be extended by calling async_register_preload_platform
#
BASE_PRELOAD_PLATFORMS = [
"config",
"diagnostics",
"energy",
"group",
"logbook",
"hardware",
"intent",
"media_source",
"recorder",
"repairs",
"system_health",
"trigger",
]
DATA_COMPONENTS = "components"
DATA_INTEGRATIONS = "integrations"
DATA_MISSING_PLATFORMS = "missing_platforms"
DATA_CUSTOM_COMPONENTS = "custom_components"
DATA_PRELOAD_PLATFORMS = "preload_platforms"
PACKAGE_CUSTOM_COMPONENTS = "custom_components"
PACKAGE_BUILTIN = "homeassistant.components"
CUSTOM_WARNING = (
@ -161,7 +189,7 @@ class Manifest(TypedDict, total=False):
disabled: str
domain: str
integration_type: Literal[
"entity", "device", "hardware", "helper", "hub", "service", "system"
"entity", "device", "hardware", "helper", "hub", "service", "system", "virtual"
]
dependencies: list[str]
after_dependencies: list[str]
@ -192,6 +220,7 @@ def async_setup(hass: HomeAssistant) -> None:
hass.data[DATA_COMPONENTS] = {}
hass.data[DATA_INTEGRATIONS] = {}
hass.data[DATA_MISSING_PLATFORMS] = {}
hass.data[DATA_PRELOAD_PLATFORMS] = BASE_PRELOAD_PLATFORMS.copy()
def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest:
@ -568,6 +597,14 @@ async def async_get_mqtt(hass: HomeAssistant) -> dict[str, list[str]]:
return mqtt
@callback
def async_register_preload_platform(hass: HomeAssistant, platform_name: str) -> None:
"""Register a platform to be preloaded."""
preload_platforms: list[str] = hass.data[DATA_PRELOAD_PLATFORMS]
if platform_name not in preload_platforms:
preload_platforms.append(platform_name)
class Integration:
"""An integration in Home Assistant."""
@ -590,11 +627,16 @@ class Integration:
)
continue
file_path = manifest_path.parent
# Avoid the listdir for virtual integrations
# as they cannot have any platforms
is_virtual = manifest.get("integration_type") == "virtual"
integration = cls(
hass,
f"{root_module.__name__}.{domain}",
manifest_path.parent,
file_path,
manifest,
None if is_virtual else set(os.listdir(file_path)),
)
if integration.is_built_in:
@ -647,6 +689,7 @@ class Integration:
pkg_path: str,
file_path: pathlib.Path,
manifest: Manifest,
top_level_files: set[str] | None = None,
) -> None:
"""Initialize an integration."""
self.hass = hass
@ -662,6 +705,8 @@ class Integration:
self._all_dependencies_resolved = True
self._all_dependencies = set()
platforms_to_preload: list[str] = hass.data[DATA_PRELOAD_PLATFORMS]
self._platforms_to_preload = platforms_to_preload
self._component_future: asyncio.Future[ComponentProtocol] | None = None
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS]
@ -670,6 +715,7 @@ class Integration:
DATA_MISSING_PLATFORMS
]
self._missing_platforms_cache = missing_platforms_cache
self._top_level_files = top_level_files or set()
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
@cached_property
@ -735,7 +781,9 @@ class Integration:
@cached_property
def integration_type(
self,
) -> Literal["entity", "device", "hardware", "helper", "hub", "service", "system"]:
) -> Literal[
"entity", "device", "hardware", "helper", "hub", "service", "system", "virtual"
]:
"""Return the integration type."""
return self.manifest.get("integration_type", "hub")
@ -878,7 +926,9 @@ class Integration:
self._component_future = self.hass.loop.create_future()
try:
try:
comp = await self.hass.async_add_import_executor_job(self.get_component)
comp = await self.hass.async_add_import_executor_job(
self.get_component, True
)
except ImportError as ex:
load_executor = False
_LOGGER.debug(
@ -909,7 +959,7 @@ class Integration:
return comp
def get_component(self) -> ComponentProtocol:
def get_component(self, preload_platforms: bool = False) -> ComponentProtocol:
"""Return the component.
This method must be thread-safe as it's called from the executor
@ -940,22 +990,19 @@ class Integration:
)
raise ImportError(f"Exception importing {self.pkg_path}") from err
if self.platform_exists("config"):
# Setting up a component always checks if the config
# platform exists. Since we may be running in the executor
# we will use this opportunity to cache the config platform
# as well.
with suppress(ImportError):
self.get_platform("config")
if preload_platforms:
for platform_name in self.platforms_exists(self._platforms_to_preload):
with suppress(ImportError):
self.get_platform(platform_name)
if self.config_flow:
# If there is a config flow, we will cache it as well since
# config entry setup always has to load the flow to get the
# major/minor version for migrations. Since we may be running
# in the executor we will use this opportunity to cache the
# config_flow as well.
with suppress(ImportError):
self.get_platform("config_flow")
if self.config_flow:
# If there is a config flow, we will cache it as well since
# config entry setup always has to load the flow to get the
# major/minor version for migrations. Since we may be running
# in the executor we will use this opportunity to cache the
# config_flow as well.
with suppress(ImportError):
self.get_platform("config_flow")
return cache[domain]
@ -985,7 +1032,7 @@ class Integration:
for platform_name in platform_names:
full_name = f"{domain}.{platform_name}"
if platform := self._get_platform_cached(full_name):
if platform := self._get_platform_cached_or_raise(full_name):
platforms[platform_name] = platform
continue
@ -1065,7 +1112,7 @@ class Integration:
return platforms
def _get_platform_cached(self, full_name: str) -> ModuleType | None:
def _get_platform_cached_or_raise(self, full_name: str) -> ModuleType | None:
"""Return a platform for an integration from cache."""
if full_name in self._cache:
# the cache is either a ModuleType or a ComponentProtocol
@ -1077,43 +1124,35 @@ class Integration:
def get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration."""
if platform := self._get_platform_cached(f"{self.domain}.{platform_name}"):
if platform := self._get_platform_cached_or_raise(
f"{self.domain}.{platform_name}"
):
return platform
return self._load_platform(platform_name)
def platform_exists(self, platform_name: str) -> bool | None:
"""Check if a platform exists for an integration.
def platforms_exists(self, platform_names: Iterable[str]) -> list[str]:
"""Check if a platforms exists for an integration.
Returns True if the platform exists, False if it does not.
If it cannot be determined if the platform exists without attempting
to import the component, it returns None. This will only happen
if this function is called before get_component or async_get_component
has been called for the integration or the integration failed to load.
This method is thread-safe and can be called from the executor
or event loop without doing blocking I/O.
"""
full_name = f"{self.domain}.{platform_name}"
cache = self._cache
if full_name in cache:
return True
files = self._top_level_files
domain = self.domain
existing_platforms: list[str] = []
missing_platforms = self._missing_platforms_cache
for platform_name in platform_names:
full_name = f"{domain}.{platform_name}"
if full_name not in missing_platforms and (
f"{platform_name}.py" in files or platform_name in files
):
existing_platforms.append(platform_name)
continue
missing_platforms[full_name] = ModuleNotFoundError(
f"Platform {full_name} not found",
name=f"{self.pkg_path}.{platform_name}",
)
if full_name in self._missing_platforms_cache:
return False
if not (component := cache.get(self.domain)) or not (
file := getattr(component, "__file__", None)
):
return None
path: pathlib.Path = pathlib.Path(file).parent.joinpath(platform_name)
if os.path.exists(path.with_suffix(".py")) or os.path.exists(path):
return True
exc = ModuleNotFoundError(
f"Platform {full_name} not found",
name=f"{self.pkg_path}.{platform_name}",
)
self._missing_platforms_cache[full_name] = exc
return False
return existing_platforms
def _load_platform(self, platform_name: str) -> ModuleType:
"""Load a platform for an integration.

View File

@ -1396,6 +1396,7 @@ def mock_integration(
else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}",
pathlib.Path(""),
module.mock_manifest(),
set(),
)
def mock_import_platform(platform_name: str) -> NoReturn:
@ -1423,13 +1424,14 @@ def mock_platform(
platform_path is in form hue.config_flow.
"""
domain = platform_path.split(".")[0]
domain, _, platform_name = platform_path.partition(".")
integration_cache = hass.data[loader.DATA_INTEGRATIONS]
module_cache = hass.data[loader.DATA_COMPONENTS]
if domain not in integration_cache:
mock_integration(hass, MockModule(domain))
integration_cache[domain]._top_level_files.add(f"{platform_name}.py")
_LOGGER.info("Adding mock integration platform: %s", platform_path)
module_cache[platform_path] = module or Mock()

View File

@ -1094,27 +1094,36 @@ async def test_async_get_component_preloads_config_and_config_flow(
assert "homeassistant.components.executor_import" not in sys.modules
assert "custom_components.executor_import" not in sys.modules
platform_exists_calls = []
def mock_platforms_exists(platforms: list[str]) -> bool:
platform_exists_calls.append(platforms)
return platforms
with patch(
"homeassistant.loader.importlib.import_module"
) as mock_import, patch.object(
executor_import_integration, "platform_exists", return_value=True
) as mock_platform_exists:
executor_import_integration, "platforms_exists", mock_platforms_exists
):
await executor_import_integration.async_get_component()
assert mock_platform_exists.call_count == 1
assert mock_import.call_count == 3
assert len(platform_exists_calls[0]) == len(loader.BASE_PRELOAD_PLATFORMS)
assert mock_import.call_count == 2 + len(loader.BASE_PRELOAD_PLATFORMS)
assert (
mock_import.call_args_list[0][0][0]
== "homeassistant.components.executor_import"
)
assert (
mock_import.call_args_list[1][0][0]
== "homeassistant.components.executor_import.config"
)
assert (
mock_import.call_args_list[2][0][0]
== "homeassistant.components.executor_import.config_flow"
)
checked_platforms = {
mock_import.call_args_list[i][0][0]
for i in range(1, len(mock_import.call_args_list))
}
assert checked_platforms == {
"homeassistant.components.executor_import.config_flow",
*(
f"homeassistant.components.executor_import.{platform}"
for platform in loader.BASE_PRELOAD_PLATFORMS
),
}
async def test_async_get_component_loads_loop_if_already_in_sys_modules(
@ -1134,8 +1143,8 @@ async def test_async_get_component_loads_loop_if_already_in_sys_modules(
assert "test_package_loaded_executor.config_flow" not in hass.config.components
config_flow_module_name = f"{integration.pkg_path}.config_flow"
module_mock = MagicMock()
config_flow_module_mock = MagicMock()
module_mock = MagicMock(__file__="__init__.py")
config_flow_module_mock = MagicMock(__file__="config_flow.py")
def import_module(name: str) -> Any:
if name == integration.pkg_path:
@ -1194,8 +1203,8 @@ async def test_async_get_component_concurrent_loads(
assert "test_package_loaded_executor.config_flow" not in hass.config.components
config_flow_module_name = f"{integration.pkg_path}.config_flow"
module_mock = MagicMock()
config_flow_module_mock = MagicMock()
module_mock = MagicMock(__file__="__init__.py")
config_flow_module_mock = MagicMock(__file__="config_flow.py")
imports = []
start_event = threading.Event()
import_event = asyncio.Event()
@ -1232,7 +1241,8 @@ async def test_async_get_component_concurrent_loads(
assert comp1 is module_mock
assert comp2 is module_mock
assert imports == [integration.pkg_path, config_flow_module_name]
assert integration.pkg_path in imports
assert config_flow_module_name in imports
async def test_async_get_component_deadlock_fallback(
@ -1243,7 +1253,7 @@ async def test_async_get_component_deadlock_fallback(
hass, "executor_import", True, import_executor=True
)
assert executor_import_integration.import_executor is True
module_mock = MagicMock()
module_mock = MagicMock(__file__="__init__.py")
import_attempts = 0
def mock_import(module: str, *args: Any, **kwargs: Any) -> Any:
@ -1395,38 +1405,51 @@ async def test_async_get_platform_raises_after_import_failure(
assert "loaded_executor=False" not in caplog.text
async def test_platform_exists(
async def test_platforms_exists(
hass: HomeAssistant, enable_custom_integrations: None
) -> None:
"""Test platform_exists."""
integration = await loader.async_get_integration(hass, "test_integration_platform")
assert integration.domain == "test_integration_platform"
"""Test platforms_exists."""
original_os_listdir = os.listdir
# get_component never called, will return None
assert integration.platform_exists("non_existing") is None
paths: list[str] = []
component = integration.get_component()
def mock_list_dir(path: str) -> list[str]:
paths.append(path)
return original_os_listdir(path)
with patch("homeassistant.loader.os.listdir", mock_list_dir):
integration = await loader.async_get_integration(
hass, "test_integration_platform"
)
assert integration.domain == "test_integration_platform"
# Verify the files cache is primed
assert integration.file_path in paths
# component is loaded, should now return False
with patch("homeassistant.loader.os.listdir", wraps=os.listdir) as mock_exists:
component = integration.get_component()
assert component.DOMAIN == "test_integration_platform"
# component is loaded, should now return False
with patch(
"homeassistant.loader.os.path.exists", wraps=os.path.exists
) as mock_exists:
assert integration.platform_exists("non_existing") is False
# We should check if the file exists
assert mock_exists.call_count == 2
# The files cache should be primed when
# the integration is resolved
assert mock_exists.call_count == 0
# component is loaded, should now return False
with patch(
"homeassistant.loader.os.path.exists", wraps=os.path.exists
) as mock_exists:
assert integration.platform_exists("non_existing") is False
with patch("homeassistant.loader.os.listdir", wraps=os.listdir) as mock_exists:
assert integration.platforms_exists(("non_existing",)) == []
# We should remember which files exist
assert mock_exists.call_count == 0
# component is loaded, should now return False
with patch("homeassistant.loader.os.listdir", wraps=os.listdir) as mock_exists:
assert integration.platforms_exists(("non_existing",)) == []
# We should remember the file does not exist
assert mock_exists.call_count == 0
assert integration.platform_exists("group") is True
assert integration.platforms_exists(["group"]) == ["group"]
platform = await integration.async_get_platform("group")
assert platform.MAGIC == 1
@ -1434,7 +1457,7 @@ async def test_platform_exists(
platform = integration.get_platform("group")
assert platform.MAGIC == 1
assert integration.platform_exists("group") is True
assert integration.platforms_exists(["group"]) == ["group"]
async def test_async_get_platforms_loads_loop_if_already_in_sys_modules(

View File

@ -843,7 +843,7 @@ async def test_async_prepare_setup_platform(
caplog.clear()
# There is no actual config platform for this integration
assert await setup.async_prepare_setup_platform(hass, {}, "config", "test") is None
assert "test.config not found" in caplog.text
assert "No module named 'custom_components.test.config'" in caplog.text
button_platform = (
await setup.async_prepare_setup_platform(hass, {}, "button", "test") is None