From 15618a8a974ef6dcb410ad7db5b0660aecca4d74 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 7 May 2024 18:37:01 +0200 Subject: [PATCH] Use HassKey for loader (#116999) --- homeassistant/loader.py | 45 +++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 9ecb468a8a8..3d201c1b694 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -39,6 +39,7 @@ from .generated.mqtt import MQTT from .generated.ssdp import SSDP from .generated.usb import USB from .generated.zeroconf import HOMEKIT, ZEROCONF +from .util.hass_dict import HassKey from .util.json import JSON_DECODE_EXCEPTIONS, json_loads if TYPE_CHECKING: @@ -98,11 +99,17 @@ BLOCKED_CUSTOM_INTEGRATIONS: dict[str, BlockedIntegration] = { ), } -DATA_COMPONENTS = "components" -DATA_INTEGRATIONS = "integrations" -DATA_MISSING_PLATFORMS = "missing_platforms" -DATA_CUSTOM_COMPONENTS = "custom_components" -DATA_PRELOAD_PLATFORMS = "preload_platforms" +DATA_COMPONENTS: HassKey[dict[str, ModuleType | ComponentProtocol]] = HassKey( + "components" +) +DATA_INTEGRATIONS: HassKey[dict[str, Integration | asyncio.Future[None]]] = HassKey( + "integrations" +) +DATA_MISSING_PLATFORMS: HassKey[dict[str, bool]] = HassKey("missing_platforms") +DATA_CUSTOM_COMPONENTS: HassKey[ + dict[str, Integration] | asyncio.Future[dict[str, Integration]] +] = HassKey("custom_components") +DATA_PRELOAD_PLATFORMS: HassKey[list[str]] = HassKey("preload_platforms") PACKAGE_CUSTOM_COMPONENTS = "custom_components" PACKAGE_BUILTIN = "homeassistant.components" CUSTOM_WARNING = ( @@ -298,9 +305,7 @@ async def async_get_custom_components( hass: HomeAssistant, ) -> dict[str, Integration]: """Return cached list of custom integrations.""" - comps_or_future: ( - dict[str, Integration] | asyncio.Future[dict[str, Integration]] | None - ) = hass.data.get(DATA_CUSTOM_COMPONENTS) + comps_or_future = hass.data.get(DATA_CUSTOM_COMPONENTS) if comps_or_future is None: future = hass.data[DATA_CUSTOM_COMPONENTS] = hass.loop.create_future() @@ -622,7 +627,7 @@ async def async_get_mqtt(hass: HomeAssistant) -> dict[str, list[str]]: @callback def async_register_preload_platform(hass: HomeAssistant, platform_name: str) -> None: """Register a platform to be preloaded.""" - preload_platforms: list[str] = hass.data[DATA_PRELOAD_PLATFORMS] + preload_platforms = hass.data[DATA_PRELOAD_PLATFORMS] if platform_name not in preload_platforms: preload_platforms.append(platform_name) @@ -746,14 +751,11 @@ class Integration: self._all_dependencies_resolved = True self._all_dependencies = set() - platforms_to_preload: list[str] = hass.data[DATA_PRELOAD_PLATFORMS] - self._platforms_to_preload = platforms_to_preload + self._platforms_to_preload = hass.data[DATA_PRELOAD_PLATFORMS] self._component_future: asyncio.Future[ComponentProtocol] | None = None self._import_futures: dict[str, asyncio.Future[ModuleType]] = {} - cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS] - self._cache = cache - missing_platforms_cache: dict[str, bool] = hass.data[DATA_MISSING_PLATFORMS] - self._missing_platforms_cache = missing_platforms_cache + self._cache = hass.data[DATA_COMPONENTS] + self._missing_platforms_cache = hass.data[DATA_MISSING_PLATFORMS] self._top_level_files = top_level_files or set() _LOGGER.info("Loaded %s from %s", self.domain, pkg_path) @@ -1233,7 +1235,7 @@ class Integration: appropriate locks. """ full_name = f"{self.domain}.{platform_name}" - cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS] + cache = self.hass.data[DATA_COMPONENTS] try: cache[full_name] = self._import_platform(platform_name) except ModuleNotFoundError: @@ -1259,7 +1261,7 @@ class Integration: f"Exception importing {self.pkg_path}.{platform_name}" ) from err - return cache[full_name] + return cast(ModuleType, cache[full_name]) def _import_platform(self, platform_name: str) -> ModuleType: """Import the platform. @@ -1311,8 +1313,6 @@ def async_get_loaded_integration(hass: HomeAssistant, domain: str) -> Integratio Raises IntegrationNotLoaded if the integration is not loaded. """ cache = hass.data[DATA_INTEGRATIONS] - if TYPE_CHECKING: - cache = cast(dict[str, Integration | asyncio.Future[None]], cache) int_or_fut = cache.get(domain, _UNDEF) # Integration is never subclassed, so we can check for type if type(int_or_fut) is Integration: @@ -1322,7 +1322,6 @@ def async_get_loaded_integration(hass: HomeAssistant, domain: str) -> Integratio async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration: """Get integration.""" - cache: dict[str, Integration | asyncio.Future[None]] cache = hass.data[DATA_INTEGRATIONS] if type(int_or_fut := cache.get(domain, _UNDEF)) is Integration: return int_or_fut @@ -1337,7 +1336,6 @@ async def async_get_integrations( hass: HomeAssistant, domains: Iterable[str] ) -> dict[str, Integration | Exception]: """Get integrations.""" - cache: dict[str, Integration | asyncio.Future[None]] cache = hass.data[DATA_INTEGRATIONS] results: dict[str, Integration | Exception] = {} needed: dict[str, asyncio.Future[None]] = {} @@ -1446,10 +1444,9 @@ def _load_file( Only returns it if also found to be valid. Async friendly. """ - cache: dict[str, ComponentProtocol] = hass.data[DATA_COMPONENTS] - module: ComponentProtocol | None + cache = hass.data[DATA_COMPONENTS] if module := cache.get(comp_or_platform): - return module + return cast(ComponentProtocol, module) for path in (f"{base}.{comp_or_platform}" for base in base_paths): try: