Use HassKey for registries (#117000)

This commit is contained in:
Marc Mueller 2024-05-07 18:03:14 +02:00 committed by GitHub
parent 5ad52f122d
commit fd5885ec83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 30 additions and 23 deletions

View File

@ -4,11 +4,12 @@ from __future__ import annotations
from collections.abc import Iterable
import dataclasses
from typing import Any, Literal, TypedDict, cast
from typing import Any, Literal, TypedDict
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from . import device_registry as dr, entity_registry as er
from .normalized_name_base_registry import (
@ -20,7 +21,7 @@ from .registry import BaseRegistry
from .storage import Store
from .typing import UNDEFINED, UndefinedType
DATA_REGISTRY = "area_registry"
DATA_REGISTRY: HassKey[AreaRegistry] = HassKey("area_registry")
EVENT_AREA_REGISTRY_UPDATED: EventType[EventAreaRegistryUpdatedData] = EventType(
"area_registry_updated"
)
@ -418,7 +419,7 @@ class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
@callback
def async_get(hass: HomeAssistant) -> AreaRegistry:
"""Get area registry."""
return cast(AreaRegistry, hass.data[DATA_REGISTRY])
return hass.data[DATA_REGISTRY]
async def async_load(hass: HomeAssistant) -> None:

View File

@ -5,17 +5,18 @@ from __future__ import annotations
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass, field
from typing import Literal, TypedDict, cast
from typing import Literal, TypedDict
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.ulid import ulid_now
from .registry import BaseRegistry
from .storage import Store
from .typing import UNDEFINED, UndefinedType
DATA_REGISTRY = "category_registry"
DATA_REGISTRY: HassKey[CategoryRegistry] = HassKey("category_registry")
EVENT_CATEGORY_REGISTRY_UPDATED: EventType[EventCategoryRegistryUpdatedData] = (
EventType("category_registry_updated")
)
@ -218,7 +219,7 @@ class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
@callback
def async_get(hass: HomeAssistant) -> CategoryRegistry:
"""Get category registry."""
return cast(CategoryRegistry, hass.data[DATA_REGISTRY])
return hass.data[DATA_REGISTRY]
async def async_load(hass: HomeAssistant) -> None:

View File

@ -7,7 +7,7 @@ from enum import StrEnum
from functools import cached_property, lru_cache, partial
import logging
import time
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar
import attr
from yarl import URL
@ -23,6 +23,7 @@ from homeassistant.core import (
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import async_suggest_report_issue
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import format_unserializable_data
import homeassistant.util.uuid as uuid_util
@ -46,7 +47,7 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__)
DATA_REGISTRY = "device_registry"
DATA_REGISTRY: HassKey[DeviceRegistry] = HassKey("device_registry")
EVENT_DEVICE_REGISTRY_UPDATED: EventType[EventDeviceRegistryUpdatedData] = EventType(
"device_registry_updated"
)
@ -1078,7 +1079,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
@callback
def async_get(hass: HomeAssistant) -> DeviceRegistry:
"""Get device registry."""
return cast(DeviceRegistry, hass.data[DATA_REGISTRY])
return hass.data[DATA_REGISTRY]
async def async_load(hass: HomeAssistant) -> None:

View File

@ -16,7 +16,7 @@ from enum import StrEnum
from functools import cached_property
import logging
import time
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar
import attr
import voluptuous as vol
@ -48,6 +48,7 @@ from homeassistant.exceptions import MaxLengthExceeded
from homeassistant.loader import async_suggest_report_issue
from homeassistant.util import slugify, uuid as uuid_util
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import format_unserializable_data
from homeassistant.util.read_only_dict import ReadOnlyDict
@ -65,7 +66,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
DATA_REGISTRY = "entity_registry"
DATA_REGISTRY: HassKey[EntityRegistry] = HassKey("entity_registry")
EVENT_ENTITY_REGISTRY_UPDATED: EventType[EventEntityRegistryUpdatedData] = EventType(
"entity_registry_updated"
)
@ -1375,7 +1376,7 @@ class EntityRegistry(BaseRegistry):
@callback
def async_get(hass: HomeAssistant) -> EntityRegistry:
"""Get entity registry."""
return cast(EntityRegistry, hass.data[DATA_REGISTRY])
return hass.data[DATA_REGISTRY]
async def async_load(hass: HomeAssistant) -> None:

View File

@ -5,11 +5,12 @@ from __future__ import annotations
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
from typing import Literal, TypedDict, cast
from typing import Literal, TypedDict
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
@ -20,7 +21,7 @@ from .registry import BaseRegistry
from .storage import Store
from .typing import UNDEFINED, UndefinedType
DATA_REGISTRY = "floor_registry"
DATA_REGISTRY: HassKey[FloorRegistry] = HassKey("floor_registry")
EVENT_FLOOR_REGISTRY_UPDATED: EventType[EventFloorRegistryUpdatedData] = EventType(
"floor_registry_updated"
)
@ -240,7 +241,7 @@ class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
@callback
def async_get(hass: HomeAssistant) -> FloorRegistry:
"""Get floor registry."""
return cast(FloorRegistry, hass.data[DATA_REGISTRY])
return hass.data[DATA_REGISTRY]
async def async_load(hass: HomeAssistant) -> None:

View File

@ -14,11 +14,12 @@ from homeassistant.const import __version__ as ha_version
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as dt_util
from homeassistant.util.hass_dict import HassKey
from .registry import BaseRegistry
from .storage import Store
DATA_REGISTRY = "issue_registry"
DATA_REGISTRY: HassKey[IssueRegistry] = HassKey("issue_registry")
EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED = "repairs_issue_registry_updated"
STORAGE_KEY = "repairs.issue_registry"
STORAGE_VERSION_MAJOR = 1
@ -275,7 +276,7 @@ class IssueRegistry(BaseRegistry):
@callback
def async_get(hass: HomeAssistant) -> IssueRegistry:
"""Get issue registry."""
return cast(IssueRegistry, hass.data[DATA_REGISTRY])
return hass.data[DATA_REGISTRY]
async def async_load(hass: HomeAssistant, *, read_only: bool = False) -> None:

View File

@ -5,11 +5,12 @@ from __future__ import annotations
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
from typing import Literal, TypedDict, cast
from typing import Literal, TypedDict
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify
from homeassistant.util.event_type import EventType
from homeassistant.util.hass_dict import HassKey
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
@ -20,7 +21,7 @@ from .registry import BaseRegistry
from .storage import Store
from .typing import UNDEFINED, UndefinedType
DATA_REGISTRY = "label_registry"
DATA_REGISTRY: HassKey[LabelRegistry] = HassKey("label_registry")
EVENT_LABEL_REGISTRY_UPDATED: EventType[EventLabelRegistryUpdatedData] = EventType(
"label_registry_updated"
)
@ -241,7 +242,7 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
@callback
def async_get(hass: HomeAssistant) -> LabelRegistry:
"""Get label registry."""
return cast(LabelRegistry, hass.data[DATA_REGISTRY])
return hass.data[DATA_REGISTRY]
async def async_load(hass: HomeAssistant) -> None:

View File

@ -161,7 +161,7 @@ async def test_load_save_issues(hass: HomeAssistant) -> None:
"issue_id": "issue_3",
}
registry: ir.IssueRegistry = hass.data[ir.DATA_REGISTRY]
registry = hass.data[ir.DATA_REGISTRY]
assert len(registry.issues) == 3
issue1 = registry.async_get_issue("test", "issue_1")
issue2 = registry.async_get_issue("test", "issue_2")
@ -327,7 +327,7 @@ async def test_loading_issues_from_storage(
await ir.async_load(hass)
registry: ir.IssueRegistry = hass.data[ir.DATA_REGISTRY]
registry = hass.data[ir.DATA_REGISTRY]
assert len(registry.issues) == 3
@ -357,7 +357,7 @@ async def test_migration_1_1(hass: HomeAssistant, hass_storage: dict[str, Any])
await ir.async_load(hass)
registry: ir.IssueRegistry = hass.data[ir.DATA_REGISTRY]
registry = hass.data[ir.DATA_REGISTRY]
assert len(registry.issues) == 2