From 2db64c7e6d953c9702ab0d4861c2ad30522687a5 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 7 May 2024 18:25:16 +0200 Subject: [PATCH] Use HassKey for helpers (1) (#117012) --- homeassistant/helpers/aiohttp_client.py | 21 +++++++-------- .../helpers/config_entry_oauth2_flow.py | 17 +++++++----- homeassistant/helpers/discovery_flow.py | 5 +++- homeassistant/helpers/entity_platform.py | 27 ++++++++++--------- homeassistant/helpers/event.py | 25 ++++++++++++----- homeassistant/helpers/httpx_client.py | 11 ++++---- homeassistant/helpers/icon.py | 5 ++-- homeassistant/helpers/intent.py | 5 ++-- 8 files changed, 68 insertions(+), 48 deletions(-) diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py index f5a1bb2e15f..5c4ead4e611 100644 --- a/homeassistant/helpers/aiohttp_client.py +++ b/homeassistant/helpers/aiohttp_client.py @@ -20,6 +20,7 @@ from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __v from homeassistant.core import Event, HomeAssistant, callback from homeassistant.loader import bind_hass from homeassistant.util import ssl as ssl_util +from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import json_loads from .backports.aiohttp_resolver import AsyncResolver @@ -30,8 +31,12 @@ if TYPE_CHECKING: from aiohttp.typedefs import JSONDecoder -DATA_CONNECTOR = "aiohttp_connector" -DATA_CLIENTSESSION = "aiohttp_clientsession" +DATA_CONNECTOR: HassKey[dict[tuple[bool, int], aiohttp.BaseConnector]] = HassKey( + "aiohttp_connector" +) +DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int], aiohttp.ClientSession]] = HassKey( + "aiohttp_clientsession" +) SERVER_SOFTWARE = ( f"{APPLICATION_NAME}/{__version__} " @@ -84,11 +89,7 @@ def async_get_clientsession( This method must be run in the event loop. """ session_key = _make_key(verify_ssl, family) - if DATA_CLIENTSESSION not in hass.data: - sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {} - hass.data[DATA_CLIENTSESSION] = sessions - else: - sessions = hass.data[DATA_CLIENTSESSION] + sessions = hass.data.setdefault(DATA_CLIENTSESSION, {}) if session_key not in sessions: session = _async_create_clientsession( @@ -288,11 +289,7 @@ def _async_get_connector( This method must be run in the event loop. """ connector_key = _make_key(verify_ssl, family) - if DATA_CONNECTOR not in hass.data: - connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {} - hass.data[DATA_CONNECTOR] = connectors - else: - connectors = hass.data[DATA_CONNECTOR] + connectors = hass.data.setdefault(DATA_CONNECTOR, {}) if connector_key in connectors: return connectors[connector_key] diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index caf47432623..f8395fa8b11 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -27,6 +27,7 @@ from homeassistant import config_entries from homeassistant.components import http from homeassistant.core import HomeAssistant, callback from homeassistant.loader import async_get_application_credentials +from homeassistant.util.hass_dict import HassKey from .aiohttp_client import async_get_clientsession from .network import NoURLAvailableError @@ -34,8 +35,15 @@ from .network import NoURLAvailableError _LOGGER = logging.getLogger(__name__) DATA_JWT_SECRET = "oauth2_jwt_secret" -DATA_IMPLEMENTATIONS = "oauth2_impl" -DATA_PROVIDERS = "oauth2_providers" +DATA_IMPLEMENTATIONS: HassKey[dict[str, dict[str, AbstractOAuth2Implementation]]] = ( + HassKey("oauth2_impl") +) +DATA_PROVIDERS: HassKey[ + dict[ + str, + Callable[[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]], + ] +] = HassKey("oauth2_providers") AUTH_CALLBACK_PATH = "/auth/external/callback" HEADER_FRONTEND_BASE = "HA-Frontend-Base" MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth" @@ -398,10 +406,7 @@ async def async_get_implementations( hass: HomeAssistant, domain: str ) -> dict[str, AbstractOAuth2Implementation]: """Return OAuth2 implementations for specified domain.""" - registered = cast( - dict[str, AbstractOAuth2Implementation], - hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}), - ) + registered = hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}) if DATA_PROVIDERS not in hass.data: return registered diff --git a/homeassistant/helpers/discovery_flow.py b/homeassistant/helpers/discovery_flow.py index e479a47ecfd..b850a1b66fa 100644 --- a/homeassistant/helpers/discovery_flow.py +++ b/homeassistant/helpers/discovery_flow.py @@ -10,9 +10,12 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import CoreState, Event, HomeAssistant, callback from homeassistant.loader import bind_hass from homeassistant.util.async_ import gather_with_limited_concurrency +from homeassistant.util.hass_dict import HassKey FLOW_INIT_LIMIT = 20 -DISCOVERY_FLOW_DISPATCHER = "discovery_flow_dispatcher" +DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey( + "discovery_flow_dispatcher" +) @bind_hass diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 6d55417c05e..e49eff331b9 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -34,6 +34,7 @@ from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.generated import languages from homeassistant.setup import SetupPhases, async_start_setup from homeassistant.util.async_ import create_eager_task +from homeassistant.util.hass_dict import HassKey from . import ( config_validation as cv, @@ -57,9 +58,13 @@ SLOW_ADD_ENTITY_MAX_WAIT = 15 # Per Entity SLOW_ADD_MIN_TIMEOUT = 500 PLATFORM_NOT_READY_RETRIES = 10 -DATA_ENTITY_PLATFORM = "entity_platform" -DATA_DOMAIN_ENTITIES = "domain_entities" -DATA_DOMAIN_PLATFORM_ENTITIES = "domain_platform_entities" +DATA_ENTITY_PLATFORM: HassKey[dict[str, list[EntityPlatform]]] = HassKey( + "entity_platform" +) +DATA_DOMAIN_ENTITIES: HassKey[dict[str, dict[str, Entity]]] = HassKey("domain_entities") +DATA_DOMAIN_PLATFORM_ENTITIES: HassKey[dict[tuple[str, str], dict[str, Entity]]] = ( + HassKey("domain_platform_entities") +) PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds _LOGGER = getLogger(__name__) @@ -155,20 +160,18 @@ class EntityPlatform: # with the child dict indexed by entity_id # # This is usually media_player, light, switch, etc. - domain_entities: dict[str, dict[str, Entity]] = hass.data.setdefault( + self.domain_entities = hass.data.setdefault( DATA_DOMAIN_ENTITIES, {} - ) - self.domain_entities = domain_entities.setdefault(domain, {}) + ).setdefault(domain, {}) # Storage for entities indexed by domain and platform # with the child dict indexed by entity_id # # This is usually media_player.yamaha, light.hue, switch.tplink, etc. - domain_platform_entities: dict[tuple[str, str], dict[str, Entity]] = ( - hass.data.setdefault(DATA_DOMAIN_PLATFORM_ENTITIES, {}) - ) key = (domain, platform_name) - self.domain_platform_entities = domain_platform_entities.setdefault(key, {}) + self.domain_platform_entities = hass.data.setdefault( + DATA_DOMAIN_PLATFORM_ENTITIES, {} + ).setdefault(key, {}) def __repr__(self) -> str: """Represent an EntityPlatform.""" @@ -1063,6 +1066,4 @@ def async_get_platforms( ): return [] - platforms: list[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name] - - return platforms + return hass.data[DATA_ENTITY_PLATFORM][integration_name] diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index ace819a2734..0a2a8a93461 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -38,6 +38,7 @@ from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.event_type import EventType +from homeassistant.util.hass_dict import HassKey from . import frame from .device_registry import ( @@ -54,19 +55,29 @@ from .template import RenderInfo, Template, result_as_boolean from .typing import TemplateVarsType TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" -TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" +TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey( + "track_state_change_listener" +) TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks" -TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener" +TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey( + "track_state_added_domain_listener" +) TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks" -TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener" +TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey( + "track_state_removed_domain_listener" +) TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" -TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener" +TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey( + "track_entity_registry_updated_listener" +) TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks" -TRACK_DEVICE_REGISTRY_UPDATED_LISTENER = "track_device_registry_updated_listener" +TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey( + "track_device_registry_updated_listener" +) _ALL_LISTENER = "all" _DOMAINS_LISTENER = "domains" @@ -89,7 +100,7 @@ _P = ParamSpec("_P") class _KeyedEventTracker(Generic[_TypedDictT]): """Class to track events by key.""" - listeners_key: str + listeners_key: HassKey[Callable[[], None]] callbacks_key: str event_type: EventType[_TypedDictT] | str dispatcher_callable: Callable[ @@ -373,7 +384,7 @@ def _remove_empty_listener() -> None: @callback # type: ignore[arg-type] # mypy bug? def _remove_listener( hass: HomeAssistant, - listeners_key: str, + listeners_key: HassKey[Callable[[], None]], keys: Iterable[str], job: HassJob[[Event[_TypedDictT]], Any], callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]], diff --git a/homeassistant/helpers/httpx_client.py b/homeassistant/helpers/httpx_client.py index a0112ae0843..f71042e3057 100644 --- a/homeassistant/helpers/httpx_client.py +++ b/homeassistant/helpers/httpx_client.py @@ -11,6 +11,7 @@ import httpx from homeassistant.const import APPLICATION_NAME, EVENT_HOMEASSISTANT_CLOSE, __version__ from homeassistant.core import Event, HomeAssistant, callback from homeassistant.loader import bind_hass +from homeassistant.util.hass_dict import HassKey from homeassistant.util.ssl import ( SSLCipherList, client_context, @@ -23,8 +24,10 @@ from .frame import warn_use # and we want to keep the connection open for a while so we # don't have to reconnect every time so we use 15s to match aiohttp. KEEP_ALIVE_TIMEOUT = 15 -DATA_ASYNC_CLIENT = "httpx_async_client" -DATA_ASYNC_CLIENT_NOVERIFY = "httpx_async_client_noverify" +DATA_ASYNC_CLIENT: HassKey[httpx.AsyncClient] = HassKey("httpx_async_client") +DATA_ASYNC_CLIENT_NOVERIFY: HassKey[httpx.AsyncClient] = HassKey( + "httpx_async_client_noverify" +) DEFAULT_LIMITS = limits = httpx.Limits(keepalive_expiry=KEEP_ALIVE_TIMEOUT) SERVER_SOFTWARE = ( f"{APPLICATION_NAME}/{__version__} " @@ -42,9 +45,7 @@ def get_async_client(hass: HomeAssistant, verify_ssl: bool = True) -> httpx.Asyn """ key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY - client: httpx.AsyncClient | None = hass.data.get(key) - - if client is None: + if (client := hass.data.get(key)) is None: client = hass.data[key] = create_async_httpx_client(hass, verify_ssl) return client diff --git a/homeassistant/helpers/icon.py b/homeassistant/helpers/icon.py index db90d38744a..0f72dfbd3ab 100644 --- a/homeassistant/helpers/icon.py +++ b/homeassistant/helpers/icon.py @@ -11,11 +11,12 @@ from typing import Any from homeassistant.core import HomeAssistant, callback from homeassistant.loader import Integration, async_get_integrations +from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import load_json_object from .translation import build_resources -ICON_CACHE = "icon_cache" +ICON_CACHE: HassKey[_IconsCache] = HassKey("icon_cache") _LOGGER = logging.getLogger(__name__) @@ -142,7 +143,7 @@ async def async_get_icons( components = hass.config.top_level_components if ICON_CACHE in hass.data: - cache: _IconsCache = hass.data[ICON_CACHE] + cache = hass.data[ICON_CACHE] else: cache = hass.data[ICON_CACHE] = _IconsCache(hass) diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index 2a7d57dfd37..8d7f34007f8 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -23,6 +23,7 @@ from homeassistant.const import ( from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import bind_hass +from homeassistant.util.hass_dict import HassKey from . import ( area_registry, @@ -44,7 +45,7 @@ INTENT_SET_POSITION = "HassSetPosition" SLOT_SCHEMA = vol.Schema({}, extra=vol.ALLOW_EXTRA) -DATA_KEY = "intent" +DATA_KEY: HassKey[dict[str, IntentHandler]] = HassKey("intent") SPEECH_TYPE_PLAIN = "plain" SPEECH_TYPE_SSML = "ssml" @@ -89,7 +90,7 @@ async def async_handle( assistant: str | None = None, ) -> IntentResponse: """Handle an intent.""" - handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type) + handler = hass.data.get(DATA_KEY, {}).get(intent_type) if handler is None: raise UnknownIntent(f"Unknown intent {intent_type}")