Use HassKey for helpers (2) (#117013)

This commit is contained in:
Marc Mueller 2024-05-07 18:24:13 +02:00 committed by GitHub
parent c50a340cbc
commit 8f614fb06d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 69 additions and 39 deletions

View File

@ -16,6 +16,7 @@ from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.importlib import async_import_module from homeassistant.helpers.importlib import async_import_module
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from homeassistant.util.hass_dict import HassKey
MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry() MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry()
@ -29,7 +30,7 @@ MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
DATA_REQS = "mfa_auth_module_reqs_processed" DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View File

@ -17,13 +17,14 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.importlib import async_import_module from homeassistant.helpers.importlib import async_import_module
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from homeassistant.util.hass_dict import HassKey
from ..auth_store import AuthStore from ..auth_store import AuthStore
from ..const import MFA_SESSION_EXPIRATION from ..const import MFA_SESSION_EXPIRATION
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_REQS = "auth_prov_reqs_processed" DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry() AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry()

View File

@ -32,6 +32,7 @@ from homeassistant.helpers.service import (
async_extract_referenced_entity_ids, async_extract_referenced_entity_ids,
async_register_admin_service, async_register_admin_service,
) )
from homeassistant.helpers.signal import KEY_HA_STOP
from homeassistant.helpers.template import async_load_custom_templates from homeassistant.helpers.template import async_load_custom_templates
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -386,7 +387,7 @@ async def _async_stop(hass: ha.HomeAssistant, restart: bool) -> None:
"""Stop home assistant.""" """Stop home assistant."""
exit_code = RESTART_EXIT_CODE if restart else 0 exit_code = RESTART_EXIT_CODE if restart else 0
# Track trask in hass.data. No need to cleanup, we're stopping. # Track trask in hass.data. No need to cleanup, we're stopping.
hass.data["homeassistant_stop"] = asyncio.create_task(hass.async_stop(exit_code)) hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code))
@ha.callback @ha.callback

View File

@ -20,10 +20,13 @@ from homeassistant.loader import (
bind_hass, bind_hass,
) )
from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded from homeassistant.setup import ATTR_COMPONENT, EventComponentLoaded
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.logging import catch_log_exception from homeassistant.util.logging import catch_log_exception
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_INTEGRATION_PLATFORMS = "integration_platforms" DATA_INTEGRATION_PLATFORMS: HassKey[list[IntegrationPlatform]] = HassKey(
"integration_platforms"
)
@dataclass(slots=True, frozen=True) @dataclass(slots=True, frozen=True)
@ -160,8 +163,7 @@ async def async_process_integration_platforms(
) -> None: ) -> None:
"""Process a specific platform for all current and future loaded integrations.""" """Process a specific platform for all current and future loaded integrations."""
if DATA_INTEGRATION_PLATFORMS not in hass.data: if DATA_INTEGRATION_PLATFORMS not in hass.data:
integration_platforms: list[IntegrationPlatform] = [] integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS] = []
hass.data[DATA_INTEGRATION_PLATFORMS] = integration_platforms
hass.bus.async_listen( hass.bus.async_listen(
EVENT_COMPONENT_LOADED, EVENT_COMPONENT_LOADED,
partial( partial(

View File

@ -1,12 +1,15 @@
"""Helpers to check recorder.""" """Helpers to check recorder."""
from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util.hass_dict import HassKey
DOMAIN = "recorder" DOMAIN: HassKey[RecorderData] = HassKey("recorder")
@dataclass(slots=True) @dataclass(slots=True)
@ -14,7 +17,7 @@ class RecorderData:
"""Recorder data stored in hass.data.""" """Recorder data stored in hass.data."""
recorder_platforms: dict[str, Any] = field(default_factory=dict) recorder_platforms: dict[str, Any] = field(default_factory=dict)
db_connected: asyncio.Future = field(default_factory=asyncio.Future) db_connected: asyncio.Future[bool] = field(default_factory=asyncio.Future)
def async_migration_in_progress(hass: HomeAssistant) -> bool: def async_migration_in_progress(hass: HomeAssistant) -> bool:
@ -40,5 +43,4 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
""" """
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return False return False
db_connected: asyncio.Future[bool] = hass.data[DOMAIN].db_connected return await hass.data[DOMAIN].db_connected
return await db_connected

View File

@ -11,6 +11,7 @@ from homeassistant.const import ATTR_RESTORED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant, State, callback, valid_entity_id from homeassistant.core import HomeAssistant, State, callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import json_loads from homeassistant.util.json import json_loads
from . import start from . import start
@ -20,7 +21,7 @@ from .frame import report
from .json import JSONEncoder from .json import JSONEncoder
from .storage import Store from .storage import Store
DATA_RESTORE_STATE = "restore_state" DATA_RESTORE_STATE: HassKey[RestoreStateData] = HassKey("restore_state")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -104,7 +105,7 @@ async def async_load(hass: HomeAssistant) -> None:
@callback @callback
def async_get(hass: HomeAssistant) -> RestoreStateData: def async_get(hass: HomeAssistant) -> RestoreStateData:
"""Get the restore state data helper.""" """Get the restore state data helper."""
return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE]) return hass.data[DATA_RESTORE_STATE]
class RestoreStateData: class RestoreStateData:

View File

@ -81,6 +81,7 @@ from homeassistant.core import (
from homeassistant.util import slugify from homeassistant.util import slugify
from homeassistant.util.async_ import create_eager_task from homeassistant.util.async_ import create_eager_task
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.signal_type import SignalType, SignalTypeFormat from homeassistant.util.signal_type import SignalType, SignalTypeFormat
from . import condition, config_validation as cv, service, template from . import condition, config_validation as cv, service, template
@ -133,9 +134,11 @@ DEFAULT_MAX_EXCEEDED = "WARNING"
ATTR_CUR = "current" ATTR_CUR = "current"
ATTR_MAX = "max" ATTR_MAX = "max"
DATA_SCRIPTS = "helpers.script" DATA_SCRIPTS: HassKey[list[ScriptData]] = HassKey("helpers.script")
DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints" DATA_SCRIPT_BREAKPOINTS: HassKey[dict[str, dict[str, set[str]]]] = HassKey(
DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED = "helpers.script_not_allowed" "helpers.script_breakpoints"
)
DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED: HassKey[None] = HassKey("helpers.script_not_allowed")
RUN_ID_ANY = "*" RUN_ID_ANY = "*"
NODE_ANY = "*" NODE_ANY = "*"
@ -158,6 +161,13 @@ SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None) script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
class ScriptData(TypedDict):
"""Store data related to script instance."""
instance: Script
started_before_shutdown: bool
class ScriptStoppedError(Exception): class ScriptStoppedError(Exception):
"""Error to indicate that the script has been stopped.""" """Error to indicate that the script has been stopped."""

View File

@ -47,6 +47,7 @@ from homeassistant.exceptions import (
) )
from homeassistant.loader import Integration, async_get_integrations, bind_hass from homeassistant.loader import Integration, async_get_integrations, bind_hass
from homeassistant.util.async_ import create_eager_task from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE from homeassistant.util.yaml.loader import JSON_TYPE
@ -74,8 +75,12 @@ CONF_SERVICE_ENTITY_ID = "entity_id"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache" SERVICE_DESCRIPTION_CACHE: HassKey[dict[tuple[str, str], dict[str, Any] | None]] = (
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache" HassKey("service_description_cache")
)
ALL_SERVICE_DESCRIPTIONS_CACHE: HassKey[
tuple[set[tuple[str, str]], dict[str, dict[str, Any]]]
] = HassKey("all_service_descriptions_cache")
_T = TypeVar("_T") _T = TypeVar("_T")
@ -660,9 +665,7 @@ async def async_get_all_descriptions(
hass: HomeAssistant, hass: HomeAssistant,
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls.""" """Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = ( descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
)
# We don't mutate services here so we avoid calling # We don't mutate services here so we avoid calling
# async_services which makes a copy of every services # async_services which makes a copy of every services
@ -686,7 +689,7 @@ async def async_get_all_descriptions(
previous_all_services, previous_descriptions_cache = all_cache previous_all_services, previous_descriptions_cache = all_cache
# If the services are the same, we can return the cache # If the services are the same, we can return the cache
if previous_all_services == all_services: if previous_all_services == all_services:
return previous_descriptions_cache # type: ignore[no-any-return] return previous_descriptions_cache
# Files we loaded for missing descriptions # Files we loaded for missing descriptions
loaded: dict[str, JSON_TYPE] = {} loaded: dict[str, JSON_TYPE] = {}
@ -812,9 +815,7 @@ def async_set_service_schema(
domain = domain.lower() domain = domain.lower()
service = service.lower() service = service.lower()
descriptions_cache: dict[tuple[str, str], dict[str, Any] | None] = ( descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
)
description = { description = {
"name": schema.get("name", ""), "name": schema.get("name", ""),

View File

@ -7,9 +7,12 @@ import signal
from homeassistant.const import RESTART_EXIT_CODE from homeassistant.const import RESTART_EXIT_CODE
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
KEY_HA_STOP: HassKey[asyncio.Task[None]] = HassKey("homeassistant_stop")
@callback @callback
@bind_hass @bind_hass
@ -25,9 +28,7 @@ def async_register_signal_handling(hass: HomeAssistant) -> None:
""" """
hass.loop.remove_signal_handler(signal.SIGTERM) hass.loop.remove_signal_handler(signal.SIGTERM)
hass.loop.remove_signal_handler(signal.SIGINT) hass.loop.remove_signal_handler(signal.SIGINT)
hass.data["homeassistant_stop"] = asyncio.create_task( hass.data[KEY_HA_STOP] = asyncio.create_task(hass.async_stop(exit_code))
hass.async_stop(exit_code)
)
try: try:
hass.loop.add_signal_handler(signal.SIGTERM, async_signal_handle, 0) hass.loop.add_signal_handler(signal.SIGTERM, async_signal_handle, 0)

View File

@ -32,6 +32,7 @@ from homeassistant.loader import bind_hass
from homeassistant.util import json as json_util from homeassistant.util import json as json_util
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.file import WriteError from homeassistant.util.file import WriteError
from homeassistant.util.hass_dict import HassKey
from . import json as json_helper from . import json as json_helper
@ -42,8 +43,8 @@ MAX_LOAD_CONCURRENTLY = 6
STORAGE_DIR = ".storage" STORAGE_DIR = ".storage"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STORAGE_SEMAPHORE = "storage_semaphore" STORAGE_SEMAPHORE: HassKey[asyncio.Semaphore] = HassKey("storage_semaphore")
STORAGE_MANAGER = "storage_manager" STORAGE_MANAGER: HassKey[_StoreManager] = HassKey("storage_manager")
MANAGER_CLEANUP_DELAY = 60 MANAGER_CLEANUP_DELAY = 60

View File

@ -10,12 +10,15 @@ from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING: if TYPE_CHECKING:
import astral import astral
import astral.location import astral.location
DATA_LOCATION_CACHE = "astral_location_cache" DATA_LOCATION_CACHE: HassKey[
dict[tuple[str, str, str, float, float], astral.location.Location]
] = HassKey("astral_location_cache")
ELEVATION_AGNOSTIC_EVENTS = ("noon", "midnight") ELEVATION_AGNOSTIC_EVENTS = ("noon", "midnight")

View File

@ -76,6 +76,7 @@ from homeassistant.util import (
slugify as slugify_util, slugify as slugify_util,
) )
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
from homeassistant.util.read_only_dict import ReadOnlyDict from homeassistant.util.read_only_dict import ReadOnlyDict
from homeassistant.util.thread import ThreadWithException from homeassistant.util.thread import ThreadWithException
@ -99,9 +100,13 @@ _LOGGER = logging.getLogger(__name__)
_SENTINEL = object() _SENTINEL = object()
DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S" DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S"
_ENVIRONMENT = "template.environment" _ENVIRONMENT: HassKey[TemplateEnvironment] = HassKey("template.environment")
_ENVIRONMENT_LIMITED = "template.environment_limited" _ENVIRONMENT_LIMITED: HassKey[TemplateEnvironment] = HassKey(
_ENVIRONMENT_STRICT = "template.environment_strict" "template.environment_limited"
)
_ENVIRONMENT_STRICT: HassKey[TemplateEnvironment] = HassKey(
"template.environment_strict"
)
_HASS_LOADER = "template.hass_loader" _HASS_LOADER = "template.hass_loader"
_RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{|\{#") _RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{|\{#")
@ -511,8 +516,7 @@ class Template:
wanted_env = _ENVIRONMENT_STRICT wanted_env = _ENVIRONMENT_STRICT
else: else:
wanted_env = _ENVIRONMENT wanted_env = _ENVIRONMENT
ret: TemplateEnvironment | None = self.hass.data.get(wanted_env) if (ret := self.hass.data.get(wanted_env)) is None:
if ret is None:
ret = self.hass.data[wanted_env] = TemplateEnvironment( ret = self.hass.data[wanted_env] = TemplateEnvironment(
self.hass, self._limited, self._strict, self._log_fn self.hass, self._limited, self._strict, self._log_fn
) )

View File

@ -30,6 +30,7 @@ from homeassistant.core import (
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.util.async_ import create_eager_task from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from .typing import ConfigType, TemplateVarsType from .typing import ConfigType, TemplateVarsType
@ -42,7 +43,9 @@ _PLATFORM_ALIASES = {
"time": "homeassistant", "time": "homeassistant",
} }
DATA_PLUGGABLE_ACTIONS = "pluggable_actions" DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = HassKey(
"pluggable_actions"
)
class TriggerProtocol(Protocol): class TriggerProtocol(Protocol):
@ -138,9 +141,8 @@ class PluggableAction:
def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]: def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]:
"""Return the pluggable actions registry.""" """Return the pluggable actions registry."""
if data := hass.data.get(DATA_PLUGGABLE_ACTIONS): if data := hass.data.get(DATA_PLUGGABLE_ACTIONS):
return data # type: ignore[no-any-return] return data
data = defaultdict(PluggableActionsEntry) data = hass.data[DATA_PLUGGABLE_ACTIONS] = defaultdict(PluggableActionsEntry)
hass.data[DATA_PLUGGABLE_ACTIONS] = data
return data return data
@staticmethod @staticmethod