From fd5885ec83914a2085390dfac376a75c648882ce Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 7 May 2024 18:03:14 +0200 Subject: [PATCH] Use HassKey for registries (#117000) --- homeassistant/helpers/area_registry.py | 7 ++++--- homeassistant/helpers/category_registry.py | 7 ++++--- homeassistant/helpers/device_registry.py | 7 ++++--- homeassistant/helpers/entity_registry.py | 7 ++++--- homeassistant/helpers/floor_registry.py | 7 ++++--- homeassistant/helpers/issue_registry.py | 5 +++-- homeassistant/helpers/label_registry.py | 7 ++++--- tests/helpers/test_issue_registry.py | 6 +++--- 8 files changed, 30 insertions(+), 23 deletions(-) diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 4dba510396f..96200c7b43a 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -4,11 +4,12 @@ from __future__ import annotations from collections.abc import Iterable import dataclasses -from typing import Any, Literal, TypedDict, cast +from typing import Any, Literal, TypedDict from homeassistant.core import HomeAssistant, callback from homeassistant.util import slugify from homeassistant.util.event_type import EventType +from homeassistant.util.hass_dict import HassKey from . import device_registry as dr, entity_registry as er from .normalized_name_base_registry import ( @@ -20,7 +21,7 @@ from .registry import BaseRegistry from .storage import Store from .typing import UNDEFINED, UndefinedType -DATA_REGISTRY = "area_registry" +DATA_REGISTRY: HassKey[AreaRegistry] = HassKey("area_registry") EVENT_AREA_REGISTRY_UPDATED: EventType[EventAreaRegistryUpdatedData] = EventType( "area_registry_updated" ) @@ -418,7 +419,7 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]): @callback def async_get(hass: HomeAssistant) -> AreaRegistry: """Get area registry.""" - return cast(AreaRegistry, hass.data[DATA_REGISTRY]) + return hass.data[DATA_REGISTRY] async def async_load(hass: HomeAssistant) -> None: diff --git a/homeassistant/helpers/category_registry.py b/homeassistant/helpers/category_registry.py index 4ae920055a2..dafb81d02ce 100644 --- a/homeassistant/helpers/category_registry.py +++ b/homeassistant/helpers/category_registry.py @@ -5,17 +5,18 @@ from __future__ import annotations from collections.abc import Iterable import dataclasses from dataclasses import dataclass, field -from typing import Literal, TypedDict, cast +from typing import Literal, TypedDict from homeassistant.core import Event, HomeAssistant, callback from homeassistant.util.event_type import EventType +from homeassistant.util.hass_dict import HassKey from homeassistant.util.ulid import ulid_now from .registry import BaseRegistry from .storage import Store from .typing import UNDEFINED, UndefinedType -DATA_REGISTRY = "category_registry" +DATA_REGISTRY: HassKey[CategoryRegistry] = HassKey("category_registry") EVENT_CATEGORY_REGISTRY_UPDATED: EventType[EventCategoryRegistryUpdatedData] = ( EventType("category_registry_updated") ) @@ -218,7 +219,7 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]): @callback def async_get(hass: HomeAssistant) -> CategoryRegistry: """Get category registry.""" - return cast(CategoryRegistry, hass.data[DATA_REGISTRY]) + return hass.data[DATA_REGISTRY] async def async_load(hass: HomeAssistant) -> None: diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 6b653784824..e32f2b77284 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -7,7 +7,7 @@ from enum import StrEnum from functools import cached_property, lru_cache, partial import logging import time -from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar import attr from yarl import URL @@ -23,6 +23,7 @@ from homeassistant.core import ( from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import async_suggest_report_issue from homeassistant.util.event_type import EventType +from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import format_unserializable_data import homeassistant.util.uuid as uuid_util @@ -46,7 +47,7 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) -DATA_REGISTRY = "device_registry" +DATA_REGISTRY: HassKey[DeviceRegistry] = HassKey("device_registry") EVENT_DEVICE_REGISTRY_UPDATED: EventType[EventDeviceRegistryUpdatedData] = EventType( "device_registry_updated" ) @@ -1078,7 +1079,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): @callback def async_get(hass: HomeAssistant) -> DeviceRegistry: """Get device registry.""" - return cast(DeviceRegistry, hass.data[DATA_REGISTRY]) + return hass.data[DATA_REGISTRY] async def async_load(hass: HomeAssistant) -> None: diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index c3bd3031750..ac41326ed95 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -16,7 +16,7 @@ from enum import StrEnum from functools import cached_property import logging import time -from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar import attr import voluptuous as vol @@ -48,6 +48,7 @@ from homeassistant.exceptions import MaxLengthExceeded from homeassistant.loader import async_suggest_report_issue from homeassistant.util import slugify, uuid as uuid_util from homeassistant.util.event_type import EventType +from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import format_unserializable_data from homeassistant.util.read_only_dict import ReadOnlyDict @@ -65,7 +66,7 @@ if TYPE_CHECKING: T = TypeVar("T") -DATA_REGISTRY = "entity_registry" +DATA_REGISTRY: HassKey[EntityRegistry] = HassKey("entity_registry") EVENT_ENTITY_REGISTRY_UPDATED: EventType[EventEntityRegistryUpdatedData] = EventType( "entity_registry_updated" ) @@ -1375,7 +1376,7 @@ class EntityRegistry(BaseRegistry): @callback def async_get(hass: HomeAssistant) -> EntityRegistry: """Get entity registry.""" - return cast(EntityRegistry, hass.data[DATA_REGISTRY]) + return hass.data[DATA_REGISTRY] async def async_load(hass: HomeAssistant) -> None: diff --git a/homeassistant/helpers/floor_registry.py b/homeassistant/helpers/floor_registry.py index 4a11d85176a..ad17d214b44 100644 --- a/homeassistant/helpers/floor_registry.py +++ b/homeassistant/helpers/floor_registry.py @@ -5,11 +5,12 @@ from __future__ import annotations from collections.abc import Iterable import dataclasses from dataclasses import dataclass -from typing import Literal, TypedDict, cast +from typing import Literal, TypedDict from homeassistant.core import Event, HomeAssistant, callback from homeassistant.util import slugify from homeassistant.util.event_type import EventType +from homeassistant.util.hass_dict import HassKey from .normalized_name_base_registry import ( NormalizedNameBaseRegistryEntry, @@ -20,7 +21,7 @@ from .registry import BaseRegistry from .storage import Store from .typing import UNDEFINED, UndefinedType -DATA_REGISTRY = "floor_registry" +DATA_REGISTRY: HassKey[FloorRegistry] = HassKey("floor_registry") EVENT_FLOOR_REGISTRY_UPDATED: EventType[EventFloorRegistryUpdatedData] = EventType( "floor_registry_updated" ) @@ -240,7 +241,7 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]): @callback def async_get(hass: HomeAssistant) -> FloorRegistry: """Get floor registry.""" - return cast(FloorRegistry, hass.data[DATA_REGISTRY]) + return hass.data[DATA_REGISTRY] async def async_load(hass: HomeAssistant) -> None: diff --git a/homeassistant/helpers/issue_registry.py b/homeassistant/helpers/issue_registry.py index 49dc2a36cb0..0b7ee6132a3 100644 --- a/homeassistant/helpers/issue_registry.py +++ b/homeassistant/helpers/issue_registry.py @@ -14,11 +14,12 @@ from homeassistant.const import __version__ as ha_version from homeassistant.core import HomeAssistant, callback from homeassistant.util.async_ import run_callback_threadsafe import homeassistant.util.dt as dt_util +from homeassistant.util.hass_dict import HassKey from .registry import BaseRegistry from .storage import Store -DATA_REGISTRY = "issue_registry" +DATA_REGISTRY: HassKey[IssueRegistry] = HassKey("issue_registry") EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED = "repairs_issue_registry_updated" STORAGE_KEY = "repairs.issue_registry" STORAGE_VERSION_MAJOR = 1 @@ -275,7 +276,7 @@ class IssueRegistry(BaseRegistry): @callback def async_get(hass: HomeAssistant) -> IssueRegistry: """Get issue registry.""" - return cast(IssueRegistry, hass.data[DATA_REGISTRY]) + return hass.data[DATA_REGISTRY] async def async_load(hass: HomeAssistant, *, read_only: bool = False) -> None: diff --git a/homeassistant/helpers/label_registry.py b/homeassistant/helpers/label_registry.py index 81901c71745..8be63257de3 100644 --- a/homeassistant/helpers/label_registry.py +++ b/homeassistant/helpers/label_registry.py @@ -5,11 +5,12 @@ from __future__ import annotations from collections.abc import Iterable import dataclasses from dataclasses import dataclass -from typing import Literal, TypedDict, cast +from typing import Literal, TypedDict from homeassistant.core import Event, HomeAssistant, callback from homeassistant.util import slugify from homeassistant.util.event_type import EventType +from homeassistant.util.hass_dict import HassKey from .normalized_name_base_registry import ( NormalizedNameBaseRegistryEntry, @@ -20,7 +21,7 @@ from .registry import BaseRegistry from .storage import Store from .typing import UNDEFINED, UndefinedType -DATA_REGISTRY = "label_registry" +DATA_REGISTRY: HassKey[LabelRegistry] = HassKey("label_registry") EVENT_LABEL_REGISTRY_UPDATED: EventType[EventLabelRegistryUpdatedData] = EventType( "label_registry_updated" ) @@ -241,7 +242,7 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]): @callback def async_get(hass: HomeAssistant) -> LabelRegistry: """Get label registry.""" - return cast(LabelRegistry, hass.data[DATA_REGISTRY]) + return hass.data[DATA_REGISTRY] async def async_load(hass: HomeAssistant) -> None: diff --git a/tests/helpers/test_issue_registry.py b/tests/helpers/test_issue_registry.py index eb6a32540e9..19644de8baf 100644 --- a/tests/helpers/test_issue_registry.py +++ b/tests/helpers/test_issue_registry.py @@ -161,7 +161,7 @@ async def test_load_save_issues(hass: HomeAssistant) -> None: "issue_id": "issue_3", } - registry: ir.IssueRegistry = hass.data[ir.DATA_REGISTRY] + registry = hass.data[ir.DATA_REGISTRY] assert len(registry.issues) == 3 issue1 = registry.async_get_issue("test", "issue_1") issue2 = registry.async_get_issue("test", "issue_2") @@ -327,7 +327,7 @@ async def test_loading_issues_from_storage( await ir.async_load(hass) - registry: ir.IssueRegistry = hass.data[ir.DATA_REGISTRY] + registry = hass.data[ir.DATA_REGISTRY] assert len(registry.issues) == 3 @@ -357,7 +357,7 @@ async def test_migration_1_1(hass: HomeAssistant, hass_storage: dict[str, Any]) await ir.async_load(hass) - registry: ir.IssueRegistry = hass.data[ir.DATA_REGISTRY] + registry = hass.data[ir.DATA_REGISTRY] assert len(registry.issues) == 2