Use HassKey for loader (#116999)

This commit is contained in:
Marc Mueller 2024-05-07 18:37:01 +02:00 committed by GitHub
parent b21632ad05
commit 15618a8a97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,6 +39,7 @@ from .generated.mqtt import MQTT
from .generated.ssdp import SSDP from .generated.ssdp import SSDP
from .generated.usb import USB from .generated.usb import USB
from .generated.zeroconf import HOMEKIT, ZEROCONF from .generated.zeroconf import HOMEKIT, ZEROCONF
from .util.hass_dict import HassKey
from .util.json import JSON_DECODE_EXCEPTIONS, json_loads from .util.json import JSON_DECODE_EXCEPTIONS, json_loads
if TYPE_CHECKING: if TYPE_CHECKING:
@ -98,11 +99,17 @@ BLOCKED_CUSTOM_INTEGRATIONS: dict[str, BlockedIntegration] = {
), ),
} }
DATA_COMPONENTS = "components" DATA_COMPONENTS: HassKey[dict[str, ModuleType | ComponentProtocol]] = HassKey(
DATA_INTEGRATIONS = "integrations" "components"
DATA_MISSING_PLATFORMS = "missing_platforms" )
DATA_CUSTOM_COMPONENTS = "custom_components" DATA_INTEGRATIONS: HassKey[dict[str, Integration | asyncio.Future[None]]] = HassKey(
DATA_PRELOAD_PLATFORMS = "preload_platforms" "integrations"
)
DATA_MISSING_PLATFORMS: HassKey[dict[str, bool]] = HassKey("missing_platforms")
DATA_CUSTOM_COMPONENTS: HassKey[
dict[str, Integration] | asyncio.Future[dict[str, Integration]]
] = HassKey("custom_components")
DATA_PRELOAD_PLATFORMS: HassKey[list[str]] = HassKey("preload_platforms")
PACKAGE_CUSTOM_COMPONENTS = "custom_components" PACKAGE_CUSTOM_COMPONENTS = "custom_components"
PACKAGE_BUILTIN = "homeassistant.components" PACKAGE_BUILTIN = "homeassistant.components"
CUSTOM_WARNING = ( CUSTOM_WARNING = (
@ -298,9 +305,7 @@ async def async_get_custom_components(
hass: HomeAssistant, hass: HomeAssistant,
) -> dict[str, Integration]: ) -> dict[str, Integration]:
"""Return cached list of custom integrations.""" """Return cached list of custom integrations."""
comps_or_future: ( comps_or_future = hass.data.get(DATA_CUSTOM_COMPONENTS)
dict[str, Integration] | asyncio.Future[dict[str, Integration]] | None
) = hass.data.get(DATA_CUSTOM_COMPONENTS)
if comps_or_future is None: if comps_or_future is None:
future = hass.data[DATA_CUSTOM_COMPONENTS] = hass.loop.create_future() future = hass.data[DATA_CUSTOM_COMPONENTS] = hass.loop.create_future()
@ -622,7 +627,7 @@ async def async_get_mqtt(hass: HomeAssistant) -> dict[str, list[str]]:
@callback @callback
def async_register_preload_platform(hass: HomeAssistant, platform_name: str) -> None: def async_register_preload_platform(hass: HomeAssistant, platform_name: str) -> None:
"""Register a platform to be preloaded.""" """Register a platform to be preloaded."""
preload_platforms: list[str] = hass.data[DATA_PRELOAD_PLATFORMS] preload_platforms = hass.data[DATA_PRELOAD_PLATFORMS]
if platform_name not in preload_platforms: if platform_name not in preload_platforms:
preload_platforms.append(platform_name) preload_platforms.append(platform_name)
@ -746,14 +751,11 @@ class Integration:
self._all_dependencies_resolved = True self._all_dependencies_resolved = True
self._all_dependencies = set() self._all_dependencies = set()
platforms_to_preload: list[str] = hass.data[DATA_PRELOAD_PLATFORMS] self._platforms_to_preload = hass.data[DATA_PRELOAD_PLATFORMS]
self._platforms_to_preload = platforms_to_preload
self._component_future: asyncio.Future[ComponentProtocol] | None = None self._component_future: asyncio.Future[ComponentProtocol] | None = None
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {} self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS] self._cache = hass.data[DATA_COMPONENTS]
self._cache = cache self._missing_platforms_cache = hass.data[DATA_MISSING_PLATFORMS]
missing_platforms_cache: dict[str, bool] = hass.data[DATA_MISSING_PLATFORMS]
self._missing_platforms_cache = missing_platforms_cache
self._top_level_files = top_level_files or set() self._top_level_files = top_level_files or set()
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path) _LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
@ -1233,7 +1235,7 @@ class Integration:
appropriate locks. appropriate locks.
""" """
full_name = f"{self.domain}.{platform_name}" full_name = f"{self.domain}.{platform_name}"
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS] cache = self.hass.data[DATA_COMPONENTS]
try: try:
cache[full_name] = self._import_platform(platform_name) cache[full_name] = self._import_platform(platform_name)
except ModuleNotFoundError: except ModuleNotFoundError:
@ -1259,7 +1261,7 @@ class Integration:
f"Exception importing {self.pkg_path}.{platform_name}" f"Exception importing {self.pkg_path}.{platform_name}"
) from err ) from err
return cache[full_name] return cast(ModuleType, cache[full_name])
def _import_platform(self, platform_name: str) -> ModuleType: def _import_platform(self, platform_name: str) -> ModuleType:
"""Import the platform. """Import the platform.
@ -1311,8 +1313,6 @@ def async_get_loaded_integration(hass: HomeAssistant, domain: str) -> Integratio
Raises IntegrationNotLoaded if the integration is not loaded. Raises IntegrationNotLoaded if the integration is not loaded.
""" """
cache = hass.data[DATA_INTEGRATIONS] cache = hass.data[DATA_INTEGRATIONS]
if TYPE_CHECKING:
cache = cast(dict[str, Integration | asyncio.Future[None]], cache)
int_or_fut = cache.get(domain, _UNDEF) int_or_fut = cache.get(domain, _UNDEF)
# Integration is never subclassed, so we can check for type # Integration is never subclassed, so we can check for type
if type(int_or_fut) is Integration: if type(int_or_fut) is Integration:
@ -1322,7 +1322,6 @@ def async_get_loaded_integration(hass: HomeAssistant, domain: str) -> Integratio
async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration: async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration:
"""Get integration.""" """Get integration."""
cache: dict[str, Integration | asyncio.Future[None]]
cache = hass.data[DATA_INTEGRATIONS] cache = hass.data[DATA_INTEGRATIONS]
if type(int_or_fut := cache.get(domain, _UNDEF)) is Integration: if type(int_or_fut := cache.get(domain, _UNDEF)) is Integration:
return int_or_fut return int_or_fut
@ -1337,7 +1336,6 @@ async def async_get_integrations(
hass: HomeAssistant, domains: Iterable[str] hass: HomeAssistant, domains: Iterable[str]
) -> dict[str, Integration | Exception]: ) -> dict[str, Integration | Exception]:
"""Get integrations.""" """Get integrations."""
cache: dict[str, Integration | asyncio.Future[None]]
cache = hass.data[DATA_INTEGRATIONS] cache = hass.data[DATA_INTEGRATIONS]
results: dict[str, Integration | Exception] = {} results: dict[str, Integration | Exception] = {}
needed: dict[str, asyncio.Future[None]] = {} needed: dict[str, asyncio.Future[None]] = {}
@ -1446,10 +1444,9 @@ def _load_file(
Only returns it if also found to be valid. Only returns it if also found to be valid.
Async friendly. Async friendly.
""" """
cache: dict[str, ComponentProtocol] = hass.data[DATA_COMPONENTS] cache = hass.data[DATA_COMPONENTS]
module: ComponentProtocol | None
if module := cache.get(comp_or_platform): if module := cache.get(comp_or_platform):
return module return cast(ComponentProtocol, module)
for path in (f"{base}.{comp_or_platform}" for base in base_paths): for path in (f"{base}.{comp_or_platform}" for base in base_paths):
try: try: