From 16900dcef15bdb9016feabd12bfec94d61ed4df6 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Sat, 9 Jul 2022 22:32:57 +0200 Subject: [PATCH] Make Store a generic class (#74617) --- homeassistant/auth/auth_store.py | 5 +++-- homeassistant/auth/mfa_modules/notify.py | 10 ++++------ homeassistant/auth/mfa_modules/totp.py | 12 +++++------- homeassistant/auth/providers/homeassistant.py | 13 ++++++------- homeassistant/components/almond/__init__.py | 6 +++--- .../components/ambiclimate/climate.py | 2 +- .../components/analytics/analytics.py | 8 ++++---- homeassistant/components/camera/prefs.py | 8 ++++---- homeassistant/components/energy/data.py | 10 ++++++---- homeassistant/components/hassio/__init__.py | 4 +--- .../components/homekit_controller/storage.py | 13 +++++++------ homeassistant/components/http/__init__.py | 12 +++++++----- homeassistant/components/http/auth.py | 6 +++--- .../components/mobile_app/__init__.py | 3 ++- homeassistant/components/nest/media_source.py | 15 +++++---------- homeassistant/components/network/network.py | 10 ++++++---- .../resolution_center/issue_registry.py | 7 +++++-- .../components/smartthings/smartapp.py | 7 ++++--- homeassistant/components/trace/__init__.py | 4 +++- homeassistant/components/zha/core/store.py | 4 ++-- homeassistant/config_entries.py | 4 +++- homeassistant/core.py | 6 +++--- homeassistant/helpers/area_registry.py | 9 ++++++--- homeassistant/helpers/device_registry.py | 3 +-- homeassistant/helpers/instance_id.py | 2 +- homeassistant/helpers/restore_state.py | 2 +- homeassistant/helpers/storage.py | 18 ++++++++++-------- 27 files changed, 106 insertions(+), 97 deletions(-) diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index baf5a8bf3b3..2597781dc60 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -46,7 +46,7 @@ class AuthStore: self._users: dict[str, models.User] | None = None self._groups: dict[str, models.Group] | None = None self._perm_lookup: PermissionLookup | None = None - self._store = Store( + self._store = Store[dict[str, list[dict[str, Any]]]]( hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True ) self._lock = asyncio.Lock() @@ -483,9 +483,10 @@ class AuthStore: jwt_key=rt_dict["jwt_key"], last_used_at=last_used_at, last_used_ip=rt_dict.get("last_used_ip"), - credential=credentials.get(rt_dict.get("credential_id")), version=rt_dict.get("version"), ) + if "credential_id" in rt_dict: + token.credential = credentials.get(rt_dict["credential_id"]) users[rt_dict["user_id"]].refresh_tokens[token.id] = token self._groups = groups diff --git a/homeassistant/auth/mfa_modules/notify.py b/homeassistant/auth/mfa_modules/notify.py index 3872257a205..464ce495050 100644 --- a/homeassistant/auth/mfa_modules/notify.py +++ b/homeassistant/auth/mfa_modules/notify.py @@ -7,7 +7,7 @@ from __future__ import annotations import asyncio from collections import OrderedDict import logging -from typing import Any +from typing import Any, cast import attr import voluptuous as vol @@ -100,7 +100,7 @@ class NotifyAuthModule(MultiFactorAuthModule): """Initialize the user data store.""" super().__init__(hass, config) self._user_settings: _UsersDict | None = None - self._user_store = Store( + self._user_store = Store[dict[str, dict[str, Any]]]( hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True ) self._include = config.get(CONF_INCLUDE, []) @@ -119,10 +119,8 @@ class NotifyAuthModule(MultiFactorAuthModule): if self._user_settings is not None: return - if (data := await self._user_store.async_load()) is None or not isinstance( - data, dict - ): - data = {STORAGE_USERS: {}} + if (data := await self._user_store.async_load()) is None: + data = cast(dict[str, dict[str, Any]], {STORAGE_USERS: {}}) self._user_settings = { user_id: NotifySetting(**setting) diff --git a/homeassistant/auth/mfa_modules/totp.py b/homeassistant/auth/mfa_modules/totp.py index e503198f08b..397a7fcd386 100644 --- a/homeassistant/auth/mfa_modules/totp.py +++ b/homeassistant/auth/mfa_modules/totp.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from io import BytesIO -from typing import Any +from typing import Any, cast import voluptuous as vol @@ -77,7 +77,7 @@ class TotpAuthModule(MultiFactorAuthModule): """Initialize the user data store.""" super().__init__(hass, config) self._users: dict[str, str] | None = None - self._user_store = Store( + self._user_store = Store[dict[str, dict[str, str]]]( hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True ) self._init_lock = asyncio.Lock() @@ -93,16 +93,14 @@ class TotpAuthModule(MultiFactorAuthModule): if self._users is not None: return - if (data := await self._user_store.async_load()) is None or not isinstance( - data, dict - ): - data = {STORAGE_USERS: {}} + if (data := await self._user_store.async_load()) is None: + data = cast(dict[str, dict[str, str]], {STORAGE_USERS: {}}) self._users = data.get(STORAGE_USERS, {}) async def _async_save(self) -> None: """Save data.""" - await self._user_store.async_save({STORAGE_USERS: self._users}) + await self._user_store.async_save({STORAGE_USERS: self._users or {}}) def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str: """Create a ota_secret for user.""" diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index cb95907c9b2..d190a618596 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -61,10 +61,10 @@ class Data: def __init__(self, hass: HomeAssistant) -> None: """Initialize the user data store.""" self.hass = hass - self._store = Store( + self._store = Store[dict[str, list[dict[str, str]]]]( hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True ) - self._data: dict[str, Any] | None = None + self._data: dict[str, list[dict[str, str]]] | None = None # Legacy mode will allow usernames to start/end with whitespace # and will compare usernames case-insensitive. # Remove in 2020 or when we launch 1.0. @@ -80,10 +80,8 @@ class Data: async def async_load(self) -> None: """Load stored data.""" - if (data := await self._store.async_load()) is None or not isinstance( - data, dict - ): - data = {"users": []} + if (data := await self._store.async_load()) is None: + data = cast(dict[str, list[dict[str, str]]], {"users": []}) seen: set[str] = set() @@ -123,7 +121,8 @@ class Data: @property def users(self) -> list[dict[str, str]]: """Return users.""" - return self._data["users"] # type: ignore[index,no-any-return] + assert self._data is not None + return self._data["users"] def validate_login(self, username: str, password: str) -> None: """Validate a username and password. diff --git a/homeassistant/components/almond/__init__.py b/homeassistant/components/almond/__init__.py index 15c280d9c1e..09ff85491ba 100644 --- a/homeassistant/components/almond/__init__.py +++ b/homeassistant/components/almond/__init__.py @@ -5,7 +5,7 @@ import asyncio from datetime import timedelta import logging import time -from typing import Optional, cast +from typing import Any from aiohttp import ClientError, ClientSession import async_timeout @@ -167,8 +167,8 @@ async def _configure_almond_for_ha( return _LOGGER.debug("Configuring Almond to connect to Home Assistant at %s", hass_url) - store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) - data = cast(Optional[dict], await store.async_load()) + store = storage.Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) + data = await store.async_load() if data is None: data = {} diff --git a/homeassistant/components/ambiclimate/climate.py b/homeassistant/components/ambiclimate/climate.py index 93d9348655f..50135693ff4 100644 --- a/homeassistant/components/ambiclimate/climate.py +++ b/homeassistant/components/ambiclimate/climate.py @@ -64,7 +64,7 @@ async def async_setup_entry( """Set up the Ambiclimate device from config entry.""" config = entry.data websession = async_get_clientsession(hass) - store = Store(hass, STORAGE_VERSION, STORAGE_KEY) + store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) token_info = await store.async_load() oauth = ambiclimate.AmbiclimateOAuth( diff --git a/homeassistant/components/analytics/analytics.py b/homeassistant/components/analytics/analytics.py index 802aa33585a..5bb0368b021 100644 --- a/homeassistant/components/analytics/analytics.py +++ b/homeassistant/components/analytics/analytics.py @@ -1,6 +1,6 @@ """Analytics helper class for the analytics integration.""" import asyncio -from typing import cast +from typing import Any import uuid import aiohttp @@ -66,12 +66,12 @@ class Analytics: """Initialize the Analytics class.""" self.hass: HomeAssistant = hass self.session = async_get_clientsession(hass) - self._data: dict = { + self._data: dict[str, Any] = { ATTR_PREFERENCES: {}, ATTR_ONBOARDED: False, ATTR_UUID: None, } - self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY) + self._store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) @property def preferences(self) -> dict: @@ -109,7 +109,7 @@ class Analytics: async def load(self) -> None: """Load preferences.""" - stored = cast(dict, await self._store.async_load()) + stored = await self._store.async_load() if stored: self._data = stored diff --git a/homeassistant/components/camera/prefs.py b/homeassistant/components/camera/prefs.py index 3d54c10d09a..08c57631a1b 100644 --- a/homeassistant/components/camera/prefs.py +++ b/homeassistant/components/camera/prefs.py @@ -36,14 +36,14 @@ class CameraPreferences: def __init__(self, hass: HomeAssistant) -> None: """Initialize camera prefs.""" self._hass = hass - self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY) + self._store = Store[dict[str, dict[str, bool]]]( + hass, STORAGE_VERSION, STORAGE_KEY + ) self._prefs: dict[str, dict[str, bool]] | None = None async def async_initialize(self) -> None: """Finish initializing the preferences.""" - if (prefs := await self._store.async_load()) is None or not isinstance( - prefs, dict - ): + if (prefs := await self._store.async_load()) is None: prefs = {} self._prefs = prefs diff --git a/homeassistant/components/energy/data.py b/homeassistant/components/energy/data.py index d33f915628d..e8c62da0c3c 100644 --- a/homeassistant/components/energy/data.py +++ b/homeassistant/components/energy/data.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from collections import Counter from collections.abc import Awaitable, Callable -from typing import Literal, Optional, TypedDict, Union, cast +from typing import Literal, TypedDict, Union import voluptuous as vol @@ -263,13 +263,15 @@ class EnergyManager: def __init__(self, hass: HomeAssistant) -> None: """Initialize energy manager.""" self._hass = hass - self._store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) + self._store = storage.Store[EnergyPreferences]( + hass, STORAGE_VERSION, STORAGE_KEY + ) self.data: EnergyPreferences | None = None self._update_listeners: list[Callable[[], Awaitable]] = [] async def async_initialize(self) -> None: """Initialize the energy integration.""" - self.data = cast(Optional[EnergyPreferences], await self._store.async_load()) + self.data = await self._store.async_load() @staticmethod def default_preferences() -> EnergyPreferences: @@ -294,7 +296,7 @@ class EnergyManager: data[key] = update[key] # type: ignore[literal-required] self.data = data - self._store.async_delay_save(lambda: cast(dict, self.data), 60) + self._store.async_delay_save(lambda: data, 60) if not self._update_listeners: return diff --git a/homeassistant/components/hassio/__init__.py b/homeassistant/components/hassio/__init__.py index d580847646d..46592cbc20c 100644 --- a/homeassistant/components/hassio/__init__.py +++ b/homeassistant/components/hassio/__init__.py @@ -533,12 +533,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa: if not await hassio.is_connected(): _LOGGER.warning("Not connected with the supervisor / system too busy!") - store = Store(hass, STORAGE_VERSION, STORAGE_KEY) + store = Store[dict[str, str]](hass, STORAGE_VERSION, STORAGE_KEY) if (data := await store.async_load()) is None: data = {} - assert isinstance(data, dict) - refresh_token = None if "hassio_user" in data: user = await hass.auth.async_get_user(data["hassio_user"]) diff --git a/homeassistant/components/homekit_controller/storage.py b/homeassistant/components/homekit_controller/storage.py index 9372764a88a..ff39c52627e 100644 --- a/homeassistant/components/homekit_controller/storage.py +++ b/homeassistant/components/homekit_controller/storage.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, TypedDict, cast +from typing import Any, TypedDict from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.storage import Store @@ -46,7 +46,9 @@ class EntityMapStorage: def __init__(self, hass: HomeAssistant) -> None: """Create a new entity map store.""" self.hass = hass - self.store = Store(hass, ENTITY_MAP_STORAGE_VERSION, ENTITY_MAP_STORAGE_KEY) + self.store = Store[StorageLayout]( + hass, ENTITY_MAP_STORAGE_VERSION, ENTITY_MAP_STORAGE_KEY + ) self.storage_data: dict[str, Pairing] = {} async def async_initialize(self) -> None: @@ -55,8 +57,7 @@ class EntityMapStorage: # There is no cached data about HomeKit devices yet return - storage = cast(StorageLayout, raw_storage) - self.storage_data = storage.get("pairings", {}) + self.storage_data = raw_storage.get("pairings", {}) def get_map(self, homekit_id: str) -> Pairing | None: """Get a pairing cache item.""" @@ -87,6 +88,6 @@ class EntityMapStorage: self.store.async_delay_save(self._data_to_save, ENTITY_MAP_SAVE_DELAY) @callback - def _data_to_save(self) -> dict[str, Any]: + def _data_to_save(self) -> StorageLayout: """Return data of entity map to store in a file.""" - return {"pairings": self.storage_data} + return StorageLayout(pairings=self.storage_data) diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 374e69975ce..7c8594bdd90 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -7,7 +7,7 @@ import logging import os import ssl from tempfile import NamedTemporaryFile -from typing import Any, Final, Optional, TypedDict, Union, cast +from typing import Any, Final, TypedDict, Union, cast from aiohttp import web from aiohttp.typedefs import StrOrURL @@ -125,10 +125,10 @@ class ConfData(TypedDict, total=False): @bind_hass -async def async_get_last_config(hass: HomeAssistant) -> dict | None: +async def async_get_last_config(hass: HomeAssistant) -> dict[str, Any] | None: """Return the last known working config.""" - store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) - return cast(Optional[dict], await store.async_load()) + store = storage.Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) + return await store.async_load() class ApiConfig: @@ -475,7 +475,9 @@ async def start_http_server_and_save_config( await server.start() # If we are set up successful, we store the HTTP settings for safe mode. - store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) + store: storage.Store[dict[str, Any]] = storage.Store( + hass, STORAGE_VERSION, STORAGE_KEY + ) if CONF_TRUSTED_PROXIES in conf: conf[CONF_TRUSTED_PROXIES] = [ diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 18f68cc386f..7c6f445ce80 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -6,7 +6,7 @@ from datetime import timedelta from ipaddress import ip_address import logging import secrets -from typing import Final +from typing import Any, Final from aiohttp import hdrs from aiohttp.web import Application, Request, StreamResponse, middleware @@ -118,8 +118,8 @@ def async_user_not_allowed_do_auth( async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: """Create auth middleware for the app.""" - store = Store(hass, STORAGE_VERSION, STORAGE_KEY) - if (data := await store.async_load()) is None or not isinstance(data, dict): + store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) + if (data := await store.async_load()) is None: data = {} refresh_token = None diff --git a/homeassistant/components/mobile_app/__init__.py b/homeassistant/components/mobile_app/__init__.py index 7aac961042b..70c23da66e2 100644 --- a/homeassistant/components/mobile_app/__init__.py +++ b/homeassistant/components/mobile_app/__init__.py @@ -1,5 +1,6 @@ """Integrates Native Apps to Home Assistant.""" from contextlib import suppress +from typing import Any from homeassistant.components import cloud, notify as hass_notify from homeassistant.components.webhook import ( @@ -38,7 +39,7 @@ PLATFORMS = [Platform.SENSOR, Platform.BINARY_SENSOR, Platform.DEVICE_TRACKER] async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the mobile app component.""" - store = Store(hass, STORAGE_VERSION, STORAGE_KEY) + store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) if (app_config := await store.async_load()) is None or not isinstance( app_config, dict ): diff --git a/homeassistant/components/nest/media_source.py b/homeassistant/components/nest/media_source.py index 4614d4b1ed4..7b5b96b7145 100644 --- a/homeassistant/components/nest/media_source.py +++ b/homeassistant/components/nest/media_source.py @@ -22,6 +22,7 @@ from collections.abc import Mapping from dataclasses import dataclass import logging import os +from typing import Any from google_nest_sdm.camera_traits import CameraClipPreviewTrait, CameraEventImageTrait from google_nest_sdm.device import Device @@ -89,7 +90,7 @@ async def async_get_media_event_store( os.makedirs(media_path, exist_ok=True) await hass.async_add_executor_job(mkdir) - store = Store(hass, STORAGE_VERSION, STORAGE_KEY, private=True) + store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY, private=True) return NestEventMediaStore(hass, subscriber, store, media_path) @@ -119,7 +120,7 @@ class NestEventMediaStore(EventMediaStore): self, hass: HomeAssistant, subscriber: GoogleNestSubscriber, - store: Store, + store: Store[dict[str, Any]], media_path: str, ) -> None: """Initialize NestEventMediaStore.""" @@ -127,7 +128,7 @@ class NestEventMediaStore(EventMediaStore): self._subscriber = subscriber self._store = store self._media_path = media_path - self._data: dict | None = None + self._data: dict[str, Any] | None = None self._devices: Mapping[str, str] | None = {} async def async_load(self) -> dict | None: @@ -137,15 +138,9 @@ class NestEventMediaStore(EventMediaStore): if (data := await self._store.async_load()) is None: _LOGGER.debug("Loaded empty event store") self._data = {} - elif isinstance(data, dict): + else: _LOGGER.debug("Loaded event store with %d records", len(data)) self._data = data - else: - raise ValueError( - "Unexpected data in storage version={}, key={}".format( - STORAGE_VERSION, STORAGE_KEY - ) - ) return self._data async def async_save(self, data: dict) -> None: diff --git a/homeassistant/components/network/network.py b/homeassistant/components/network/network.py index b2caf6438bd..e9542ec2d54 100644 --- a/homeassistant/components/network/network.py +++ b/homeassistant/components/network/network.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from typing import Any, cast +from typing import Any from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.singleton import singleton @@ -38,8 +38,10 @@ class Network: def __init__(self, hass: HomeAssistant) -> None: """Initialize the Network class.""" - self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True) - self._data: dict[str, Any] = {} + self._store = Store[dict[str, list[str]]]( + hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True + ) + self._data: dict[str, list[str]] = {} self.adapters: list[Adapter] = [] @property @@ -67,7 +69,7 @@ class Network: async def async_load(self) -> None: """Load config.""" if stored := await self._store.async_load(): - self._data = cast(dict, stored) + self._data = stored async def _async_save(self) -> None: """Save preferences.""" diff --git a/homeassistant/components/resolution_center/issue_registry.py b/homeassistant/components/resolution_center/issue_registry.py index d97ad73bbac..7d5bbb482ba 100644 --- a/homeassistant/components/resolution_center/issue_registry.py +++ b/homeassistant/components/resolution_center/issue_registry.py @@ -2,7 +2,7 @@ from __future__ import annotations import dataclasses -from typing import cast +from typing import Optional, cast from homeassistant.const import __version__ as ha_version from homeassistant.core import HomeAssistant, callback @@ -39,7 +39,9 @@ class IssueRegistry: """Initialize the issue registry.""" self.hass = hass self.issues: dict[tuple[str, str], IssueEntry] = {} - self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True) + self._store = Store[dict[str, list[dict[str, Optional[str]]]]]( + hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True + ) @callback def async_get_issue(self, domain: str, issue_id: str) -> IssueEntry | None: @@ -119,6 +121,7 @@ class IssueRegistry: if isinstance(data, dict): for issue in data["issues"]: + assert issue["domain"] and issue["issue_id"] issues[(issue["domain"], issue["issue_id"])] = IssueEntry( active=False, breaks_in_ha_version=None, diff --git a/homeassistant/components/smartthings/smartapp.py b/homeassistant/components/smartthings/smartapp.py index fbd63d41373..adf0426e9a2 100644 --- a/homeassistant/components/smartthings/smartapp.py +++ b/homeassistant/components/smartthings/smartapp.py @@ -3,6 +3,7 @@ import asyncio import functools import logging import secrets +from typing import Any from urllib.parse import urlparse from uuid import uuid4 @@ -211,8 +212,8 @@ async def setup_smartapp_endpoint(hass: HomeAssistant): return # Get/create config to store a unique id for this hass instance. - store = Store(hass, STORAGE_VERSION, STORAGE_KEY) - if not (config := await store.async_load()) or not isinstance(config, dict): + store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) + if not (config := await store.async_load()): # Create config config = { CONF_INSTANCE_ID: str(uuid4()), @@ -283,7 +284,7 @@ async def unload_smartapp_endpoint(hass: HomeAssistant): if cloudhook_url and cloud.async_is_logged_in(hass): await cloud.async_delete_cloudhook(hass, hass.data[DOMAIN][CONF_WEBHOOK_ID]) # Remove cloudhook from storage - store = Store(hass, STORAGE_VERSION, STORAGE_KEY) + store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) await store.async_save( { CONF_INSTANCE_ID: hass.data[DOMAIN][CONF_INSTANCE_ID], diff --git a/homeassistant/components/trace/__init__.py b/homeassistant/components/trace/__init__.py index 14783fd3f84..3761ff155b4 100644 --- a/homeassistant/components/trace/__init__.py +++ b/homeassistant/components/trace/__init__.py @@ -52,7 +52,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Initialize the trace integration.""" hass.data[DATA_TRACE] = {} websocket_api.async_setup(hass) - store = Store(hass, STORAGE_VERSION, STORAGE_KEY, encoder=ExtendedJSONEncoder) + store = Store[dict[str, list]]( + hass, STORAGE_VERSION, STORAGE_KEY, encoder=ExtendedJSONEncoder + ) hass.data[DATA_TRACE_STORE] = store async def _async_store_traces_at_stop(*_) -> None: diff --git a/homeassistant/components/zha/core/store.py b/homeassistant/components/zha/core/store.py index e58dcd46dba..0b7564fe815 100644 --- a/homeassistant/components/zha/core/store.py +++ b/homeassistant/components/zha/core/store.py @@ -40,7 +40,7 @@ class ZhaStorage: """Initialize the zha device storage.""" self.hass: HomeAssistant = hass self.devices: MutableMapping[str, ZhaDeviceEntry] = {} - self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY) + self._store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY) @callback def async_create_device(self, device: ZHADevice) -> ZhaDeviceEntry: @@ -94,7 +94,7 @@ class ZhaStorage: async def async_load(self) -> None: """Load the registry of zha device entries.""" - data = cast(dict[str, Any], await self._store.async_load()) + data = await self._store.async_load() devices: OrderedDict[str, ZhaDeviceEntry] = OrderedDict() diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index b25b62aa6e0..4b76f63681a 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -845,7 +845,9 @@ class ConfigEntries: self._hass_config = hass_config self._entries: dict[str, ConfigEntry] = {} self._domain_index: dict[str, list[str]] = {} - self._store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) + self._store = storage.Store[dict[str, list[dict[str, Any]]]]( + hass, STORAGE_VERSION, STORAGE_KEY + ) EntityRegistryDisabledHandler(hass).async_setup() @callback diff --git a/homeassistant/core.py b/homeassistant/core.py index b568ee72689..7b41fe476aa 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -1942,7 +1942,7 @@ class Config: # pylint: disable=import-outside-toplevel from .helpers.storage import Store - store = Store( + store = Store[dict[str, Any]]( self.hass, CORE_STORAGE_VERSION, CORE_STORAGE_KEY, @@ -1950,7 +1950,7 @@ class Config: atomic_writes=True, ) - if not (data := await store.async_load()) or not isinstance(data, dict): + if not (data := await store.async_load()): return # In 2021.9 we fixed validation to disallow a path (because that's never correct) @@ -1998,7 +1998,7 @@ class Config: "currency": self.currency, } - store = Store( + store: Store[dict[str, Any]] = Store( self.hass, CORE_STORAGE_VERSION, CORE_STORAGE_KEY, diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index e5d35ccbf44..aeb52e8faed 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import OrderedDict from collections.abc import Container, Iterable, MutableMapping -from typing import cast +from typing import Optional, cast import attr @@ -49,7 +49,9 @@ class AreaRegistry: """Initialize the area registry.""" self.hass = hass self.areas: MutableMapping[str, AreaEntry] = {} - self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True) + self._store = Store[dict[str, list[dict[str, Optional[str]]]]]( + hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True + ) self._normalized_name_area_idx: dict[str, str] = {} @callback @@ -176,8 +178,9 @@ class AreaRegistry: areas: MutableMapping[str, AreaEntry] = OrderedDict() - if isinstance(data, dict): + if data is not None: for area in data["areas"]: + assert area["name"] is not None and area["id"] is not None normalized_name = normalize_area_name(area["name"]) areas[area["id"]] = AreaEntry( name=area["name"], diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index ed3d5a7b06f..ca5e6e1aefa 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -164,7 +164,7 @@ def _async_get_device_id_from_index( return None -class DeviceRegistryStore(storage.Store): +class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): """Store entity registry data.""" async def _async_migrate_func( @@ -569,7 +569,6 @@ class DeviceRegistry: deleted_devices = OrderedDict() if data is not None: - data = cast("dict[str, Any]", data) for device in data["devices"]: devices[device["id"]] = DeviceEntry( area_id=device["area_id"], diff --git a/homeassistant/helpers/instance_id.py b/homeassistant/helpers/instance_id.py index 59a4cf39498..8561d10794c 100644 --- a/homeassistant/helpers/instance_id.py +++ b/homeassistant/helpers/instance_id.py @@ -16,7 +16,7 @@ LEGACY_UUID_FILE = ".uuid" @singleton.singleton(DATA_KEY) async def async_get(hass: HomeAssistant) -> str: """Get unique ID for the hass instance.""" - store = storage.Store(hass, DATA_VERSION, DATA_KEY, True) + store = storage.Store[dict[str, str]](hass, DATA_VERSION, DATA_KEY, True) data: dict[str, str] | None = await storage.async_migrator( hass, diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index b8262d3a533..4f2d1dd0503 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -139,7 +139,7 @@ class RestoreStateData: def __init__(self, hass: HomeAssistant) -> None: """Initialize the restore state data class.""" self.hass: HomeAssistant = hass - self.store: Store = Store( + self.store = Store[list[dict[str, Any]]]( hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder ) self.last_states: dict[str, StoredState] = {} diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index 554a88f4ad5..6819a1eb48b 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -2,14 +2,14 @@ from __future__ import annotations import asyncio -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from contextlib import suppress from copy import deepcopy import inspect from json import JSONEncoder import logging import os -from typing import Any +from typing import Any, Generic, TypeVar, Union from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback @@ -24,6 +24,8 @@ _LOGGER = logging.getLogger(__name__) STORAGE_SEMAPHORE = "storage_semaphore" +_T = TypeVar("_T", bound=Union[Mapping[str, Any], Sequence[Any]]) + @bind_hass async def async_migrator( @@ -66,7 +68,7 @@ async def async_migrator( @bind_hass -class Store: +class Store(Generic[_T]): """Class to help storing data.""" def __init__( @@ -90,7 +92,7 @@ class Store: self._unsub_delay_listener: CALLBACK_TYPE | None = None self._unsub_final_write_listener: CALLBACK_TYPE | None = None self._write_lock = asyncio.Lock() - self._load_task: asyncio.Future | None = None + self._load_task: asyncio.Future[_T | None] | None = None self._encoder = encoder self._atomic_writes = atomic_writes @@ -99,7 +101,7 @@ class Store: """Return the config path.""" return self.hass.config.path(STORAGE_DIR, self.key) - async def async_load(self) -> dict | list | None: + async def async_load(self) -> _T | None: """Load data. If the expected version and minor version do not match the given versions, the @@ -113,7 +115,7 @@ class Store: return await self._load_task - async def _async_load(self): + async def _async_load(self) -> _T | None: """Load the data and ensure the task is removed.""" if STORAGE_SEMAPHORE not in self.hass.data: self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY) @@ -178,7 +180,7 @@ class Store: return stored - async def async_save(self, data: dict | list) -> None: + async def async_save(self, data: _T) -> None: """Save data.""" self._data = { "version": self.version, @@ -196,7 +198,7 @@ class Store: @callback def async_delay_save( self, - data_func: Callable[[], dict | list], + data_func: Callable[[], _T], delay: float = 0, ) -> None: """Save data with an optional delay."""