From 3d700e2b71d3fb8884cc6721860bc6cc5c847c81 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 7 May 2024 10:53:13 +0200 Subject: [PATCH] Add HassDict implementation (#103844) --- homeassistant/config_entries.py | 4 +- homeassistant/core.py | 3 +- homeassistant/helpers/singleton.py | 13 ++- homeassistant/setup.py | 42 ++++--- homeassistant/util/hass_dict.py | 31 +++++ homeassistant/util/hass_dict.pyi | 176 +++++++++++++++++++++++++++++ tests/test_setup.py | 7 -- tests/util/test_hass_dict.py | 47 ++++++++ 8 files changed, 287 insertions(+), 36 deletions(-) create mode 100644 homeassistant/util/hass_dict.py create mode 100644 homeassistant/util/hass_dict.pyi create mode 100644 tests/util/test_hass_dict.py diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index cc3f45df2ef..40f55ec58f8 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -2035,9 +2035,7 @@ class ConfigEntries: Config entries which are created after Home Assistant is started can't be waited for, the function will just return if the config entry is loaded or not. """ - setup_done: dict[str, asyncio.Future[bool]] = self.hass.data.get( - DATA_SETUP_DONE, {} - ) + setup_done = self.hass.data.get(DATA_SETUP_DONE, {}) if setup_future := setup_done.get(entry.domain): await setup_future # The component was not loaded. diff --git a/homeassistant/core.py b/homeassistant/core.py index 613406340bf..5a75f0ce049 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -104,6 +104,7 @@ from .util.async_ import ( ) from .util.event_type import EventType from .util.executor import InterruptibleThreadPoolExecutor +from .util.hass_dict import HassDict from .util.json import JsonObjectType from .util.read_only_dict import ReadOnlyDict from .util.timeout import TimeoutManager @@ -406,7 +407,7 @@ class HomeAssistant: from . import loader # This is a dictionary that any component can store any data on. - self.data: dict[str, Any] = {} + self.data = HassDict() self.loop = asyncio.get_running_loop() self._tasks: set[asyncio.Future[Any]] = set() self._background_tasks: set[asyncio.Future[Any]] = set() diff --git a/homeassistant/helpers/singleton.py b/homeassistant/helpers/singleton.py index bf9b6019164..d11a4cc627c 100644 --- a/homeassistant/helpers/singleton.py +++ b/homeassistant/helpers/singleton.py @@ -5,17 +5,26 @@ from __future__ import annotations import asyncio from collections.abc import Callable import functools -from typing import Any, TypeVar, cast +from typing import Any, TypeVar, cast, overload from homeassistant.core import HomeAssistant from homeassistant.loader import bind_hass +from homeassistant.util.hass_dict import HassKey _T = TypeVar("_T") _FuncType = Callable[[HomeAssistant], _T] -def singleton(data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: +@overload +def singleton(data_key: HassKey[_T]) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ... + + +@overload +def singleton(data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ... + + +def singleton(data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]: """Decorate a function that should be called once per instance. Result will be cached and simultaneous calls will be handled. diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 8d7161d04e1..b3ce02905d3 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -33,6 +33,7 @@ from .helpers import singleton, translation from .helpers.issue_registry import IssueSeverity, async_create_issue from .helpers.typing import ConfigType from .util.async_ import create_eager_task +from .util.hass_dict import HassKey current_setup_group: contextvars.ContextVar[tuple[str, str | None] | None] = ( contextvars.ContextVar("current_setup_group", default=None) @@ -45,29 +46,32 @@ ATTR_COMPONENT: Final = "component" BASE_PLATFORMS = {platform.value for platform in Platform} -# DATA_SETUP is a dict[str, asyncio.Future[bool]], indicating domains which are currently +# DATA_SETUP is a dict, indicating domains which are currently # being setup or which failed to setup: # - Tasks are added to DATA_SETUP by `async_setup_component`, the key is the domain # being setup and the Task is the `_async_setup_component` helper. # - Tasks are removed from DATA_SETUP if setup was successful, that is, # the task returned True. -DATA_SETUP = "setup_tasks" +DATA_SETUP: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_tasks") -# DATA_SETUP_DONE is a dict [str, asyncio.Future[bool]], indicating components which -# will be setup: +# DATA_SETUP_DONE is a dict, indicating components which will be setup: # - Events are added to DATA_SETUP_DONE during bootstrap by # async_set_domains_to_be_loaded, the key is the domain which will be loaded. # - Events are set and removed from DATA_SETUP_DONE when async_setup_component # is finished, regardless of if the setup was successful or not. -DATA_SETUP_DONE = "setup_done" +DATA_SETUP_DONE: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_done") -# DATA_SETUP_STARTED is a dict [tuple[str, str | None], float], indicating when an attempt +# DATA_SETUP_STARTED is a dict, indicating when an attempt # to setup a component started. -DATA_SETUP_STARTED = "setup_started" +DATA_SETUP_STARTED: HassKey[dict[tuple[str, str | None], float]] = HassKey( + "setup_started" +) -# DATA_SETUP_TIME is a defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]] -# indicating how time was spent setting up a component and each group (config entry). -DATA_SETUP_TIME = "setup_time" +# DATA_SETUP_TIME is a defaultdict, indicating how time was spent +# setting up a component. +DATA_SETUP_TIME: HassKey[ + defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]] +] = HassKey("setup_time") DATA_DEPS_REQS = "deps_reqs_processed" @@ -126,9 +130,7 @@ def async_set_domains_to_be_loaded(hass: core.HomeAssistant, domains: set[str]) - Properly handle after_dependencies. - Keep track of domains which will load but have not yet finished loading """ - setup_done_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault( - DATA_SETUP_DONE, {} - ) + setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {}) setup_done_futures.update({domain: hass.loop.create_future() for domain in domains}) @@ -149,12 +151,8 @@ async def async_setup_component( if domain in hass.config.components: return True - setup_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault( - DATA_SETUP, {} - ) - setup_done_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault( - DATA_SETUP_DONE, {} - ) + setup_futures = hass.data.setdefault(DATA_SETUP, {}) + setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {}) if existing_setup_future := setup_futures.get(domain): return await existing_setup_future @@ -195,9 +193,7 @@ async def _async_process_dependencies( Returns a list of dependencies which failed to set up. """ - setup_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault( - DATA_SETUP, {} - ) + setup_futures = hass.data.setdefault(DATA_SETUP, {}) dependencies_tasks = { dep: setup_futures.get(dep) @@ -210,7 +206,7 @@ async def _async_process_dependencies( } after_dependencies_tasks: dict[str, asyncio.Future[bool]] = {} - to_be_loaded: dict[str, asyncio.Future[bool]] = hass.data.get(DATA_SETUP_DONE, {}) + to_be_loaded = hass.data.get(DATA_SETUP_DONE, {}) for dep in integration.after_dependencies: if ( dep not in dependencies_tasks diff --git a/homeassistant/util/hass_dict.py b/homeassistant/util/hass_dict.py new file mode 100644 index 00000000000..1d0e6844798 --- /dev/null +++ b/homeassistant/util/hass_dict.py @@ -0,0 +1,31 @@ +"""Implementation for HassDict and custom HassKey types. + +Custom for type checking. See stub file. +""" + +from __future__ import annotations + +from typing import Generic, TypeVar + +_T = TypeVar("_T") + + +class HassKey(str, Generic[_T]): + """Generic Hass key type. + + At runtime this is a generic subclass of str. + """ + + __slots__ = () + + +class HassEntryKey(str, Generic[_T]): + """Key type for integrations with config entries. + + At runtime this is a generic subclass of str. + """ + + __slots__ = () + + +HassDict = dict diff --git a/homeassistant/util/hass_dict.pyi b/homeassistant/util/hass_dict.pyi new file mode 100644 index 00000000000..0e8096eeeb6 --- /dev/null +++ b/homeassistant/util/hass_dict.pyi @@ -0,0 +1,176 @@ +"""Stub file for hass_dict. Provide overload for type checking.""" +# ruff: noqa: PYI021 # Allow docstrings + +from typing import Any, Generic, TypeVar, assert_type, overload + +__all__ = [ + "HassDict", + "HassEntryKey", + "HassKey", +] + +_T = TypeVar("_T") +_U = TypeVar("_U") + +class _Key(Generic[_T]): + """Base class for Hass key types. At runtime delegated to str.""" + + def __init__(self, value: str, /) -> None: ... + def __len__(self) -> int: ... + def __hash__(self) -> int: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__(self, index: int) -> str: ... + +class HassEntryKey(_Key[_T]): + """Key type for integrations with config entries.""" + +class HassKey(_Key[_T]): + """Generic Hass key type.""" + +class HassDict(dict[_Key[Any] | str, Any]): + """Custom dict type to provide better value type hints for Hass key types.""" + + @overload # type: ignore[override] + def __getitem__(self, key: HassEntryKey[_T], /) -> dict[str, _T]: ... + @overload + def __getitem__(self, key: HassKey[_T], /) -> _T: ... + @overload + def __getitem__(self, key: str, /) -> Any: ... + + # ------ + @overload # type: ignore[override] + def __setitem__(self, key: HassEntryKey[_T], value: dict[str, _T], /) -> None: ... + @overload + def __setitem__(self, key: HassKey[_T], value: _T, /) -> None: ... + @overload + def __setitem__(self, key: str, value: Any, /) -> None: ... + + # ------ + @overload # type: ignore[override] + def setdefault( + self, key: HassEntryKey[_T], default: dict[str, _T], / + ) -> dict[str, _T]: ... + @overload + def setdefault(self, key: HassKey[_T], default: _T, /) -> _T: ... + @overload + def setdefault(self, key: str, default: None = None, /) -> Any | None: ... + @overload + def setdefault(self, key: str, default: Any, /) -> Any: ... + + # ------ + @overload # type: ignore[override] + def get(self, key: HassEntryKey[_T], /) -> dict[str, _T] | None: ... + @overload + def get(self, key: HassEntryKey[_T], default: _U, /) -> dict[str, _T] | _U: ... + @overload + def get(self, key: HassKey[_T], /) -> _T | None: ... + @overload + def get(self, key: HassKey[_T], default: _U, /) -> _T | _U: ... + @overload + def get(self, key: str, /) -> Any | None: ... + @overload + def get(self, key: str, default: Any, /) -> Any: ... + + # ------ + @overload # type: ignore[override] + def pop(self, key: HassEntryKey[_T], /) -> dict[str, _T]: ... + @overload + def pop( + self, key: HassEntryKey[_T], default: dict[str, _T], / + ) -> dict[str, _T]: ... + @overload + def pop(self, key: HassEntryKey[_T], default: _U, /) -> dict[str, _T] | _U: ... + @overload + def pop(self, key: HassKey[_T], /) -> _T: ... + @overload + def pop(self, key: HassKey[_T], default: _T, /) -> _T: ... + @overload + def pop(self, key: HassKey[_T], default: _U, /) -> _T | _U: ... + @overload + def pop(self, key: str, /) -> Any: ... + @overload + def pop(self, key: str, default: _U, /) -> Any | _U: ... + +def _test_hass_dict_typing() -> None: # noqa: PYI048 + """Test HassDict overloads work as intended. + + This is tested during the mypy run. Do not move it to 'tests'! + """ + d = HassDict() + entry_key = HassEntryKey[int]("entry_key") + key = HassKey[int]("key") + key2 = HassKey[dict[int, bool]]("key2") + key3 = HassKey[set[str]]("key3") + other_key = "domain" + + # __getitem__ + assert_type(d[entry_key], dict[str, int]) + assert_type(d[entry_key]["entry_id"], int) + assert_type(d[key], int) + assert_type(d[key2], dict[int, bool]) + + # __setitem__ + d[entry_key] = {} + d[entry_key] = 2 # type: ignore[call-overload] + d[entry_key]["entry_id"] = 2 + d[entry_key]["entry_id"] = "Hello World" # type: ignore[assignment] + d[key] = 2 + d[key] = "Hello World" # type: ignore[misc] + d[key] = {} # type: ignore[misc] + d[key2] = {} + d[key2] = 2 # type: ignore[misc] + d[key3] = set() + d[key3] = 2 # type: ignore[misc] + d[other_key] = 2 + d[other_key] = "Hello World" + + # get + assert_type(d.get(entry_key), dict[str, int] | None) + assert_type(d.get(entry_key, True), dict[str, int] | bool) + assert_type(d.get(key), int | None) + assert_type(d.get(key, True), int | bool) + assert_type(d.get(key2), dict[int, bool] | None) + assert_type(d.get(key2, {}), dict[int, bool]) + assert_type(d.get(key3), set[str] | None) + assert_type(d.get(key3, set()), set[str]) + assert_type(d.get(other_key), Any | None) + assert_type(d.get(other_key, True), Any) + assert_type(d.get(other_key, {})["id"], Any) + + # setdefault + assert_type(d.setdefault(entry_key, {}), dict[str, int]) + assert_type(d.setdefault(entry_key, {})["entry_id"], int) + assert_type(d.setdefault(key, 2), int) + assert_type(d.setdefault(key2, {}), dict[int, bool]) + assert_type(d.setdefault(key2, {})[2], bool) + assert_type(d.setdefault(key3, set()), set[str]) + assert_type(d.setdefault(other_key, 2), Any) + assert_type(d.setdefault(other_key), Any | None) + d.setdefault(entry_key, {})["entry_id"] = 2 + d.setdefault(entry_key, {})["entry_id"] = "Hello World" # type: ignore[assignment] + d.setdefault(key, 2) + d.setdefault(key, "Error") # type: ignore[misc] + d.setdefault(key2, {})[2] = True + d.setdefault(key2, {})[2] = "Error" # type: ignore[assignment] + d.setdefault(key3, set()).add("Hello World") + d.setdefault(key3, set()).add(2) # type: ignore[arg-type] + d.setdefault(other_key, {})["id"] = 2 + d.setdefault(other_key, {})["id"] = "Hello World" + d.setdefault(entry_key) # type: ignore[call-overload] + d.setdefault(key) # type: ignore[call-overload] + d.setdefault(key2) # type: ignore[call-overload] + + # pop + assert_type(d.pop(entry_key), dict[str, int]) + assert_type(d.pop(entry_key, {}), dict[str, int]) + assert_type(d.pop(entry_key, 2), dict[str, int] | int) + assert_type(d.pop(key), int) + assert_type(d.pop(key, 2), int) + assert_type(d.pop(key, "Hello World"), int | str) + assert_type(d.pop(key2), dict[int, bool]) + assert_type(d.pop(key2, {}), dict[int, bool]) + assert_type(d.pop(key2, 2), dict[int, bool] | int) + assert_type(d.pop(key3), set[str]) + assert_type(d.pop(key3, set()), set[str]) + assert_type(d.pop(other_key), Any) + assert_type(d.pop(other_key, True), Any | bool) diff --git a/tests/test_setup.py b/tests/test_setup.py index 65472643adb..50dd8bba6c5 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -739,7 +739,6 @@ async def test_integration_only_setup_entry(hass: HomeAssistant) -> None: async def test_async_start_setup_running(hass: HomeAssistant) -> None: """Test setup started context manager does nothing when running.""" assert hass.state is CoreState.running - setup_started: dict[tuple[str, str | None], float] setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) with setup.async_start_setup( @@ -753,7 +752,6 @@ async def test_async_start_setup_config_entry( ) -> None: """Test setup started keeps track of setup times with a config entry.""" hass.set_state(CoreState.not_running) - setup_started: dict[tuple[str, str | None], float] setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) @@ -864,7 +862,6 @@ async def test_async_start_setup_config_entry_late_platform( ) -> None: """Test setup started tracks config entry time with a late platform load.""" hass.set_state(CoreState.not_running) - setup_started: dict[tuple[str, str | None], float] setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) @@ -919,7 +916,6 @@ async def test_async_start_setup_config_entry_platform_wait( ) -> None: """Test setup started tracks wait time when a platform loads inside of config entry setup.""" hass.set_state(CoreState.not_running) - setup_started: dict[tuple[str, str | None], float] setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) @@ -962,7 +958,6 @@ async def test_async_start_setup_config_entry_platform_wait( async def test_async_start_setup_top_level_yaml(hass: HomeAssistant) -> None: """Test setup started context manager keeps track of setup times with modern yaml.""" hass.set_state(CoreState.not_running) - setup_started: dict[tuple[str, str | None], float] setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) @@ -979,7 +974,6 @@ async def test_async_start_setup_top_level_yaml(hass: HomeAssistant) -> None: async def test_async_start_setup_platform_integration(hass: HomeAssistant) -> None: """Test setup started keeps track of setup times a platform integration.""" hass.set_state(CoreState.not_running) - setup_started: dict[tuple[str, str | None], float] setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) @@ -1014,7 +1008,6 @@ async def test_async_start_setup_legacy_platform_integration( ) -> None: """Test setup started keeps track of setup times for a legacy platform integration.""" hass.set_state(CoreState.not_running) - setup_started: dict[tuple[str, str | None], float] setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {}) setup_time = setup._setup_times(hass) diff --git a/tests/util/test_hass_dict.py b/tests/util/test_hass_dict.py new file mode 100644 index 00000000000..36e427af41f --- /dev/null +++ b/tests/util/test_hass_dict.py @@ -0,0 +1,47 @@ +"""Test HassDict and custom HassKey types.""" + +from homeassistant.util.hass_dict import HassDict, HassEntryKey, HassKey + + +def test_key_comparison() -> None: + """Test key comparison with itself and string keys.""" + + str_key = "custom-key" + key = HassKey[int](str_key) + other_key = HassKey[str]("other-key") + + entry_key = HassEntryKey[int](str_key) + other_entry_key = HassEntryKey[str]("other-key") + + assert key == str_key + assert key != other_key + assert key != 2 + + assert entry_key == str_key + assert entry_key != other_entry_key + assert entry_key != 2 + + # Only compare name attribute, HassKey() == HassEntryKey() + assert key == entry_key + + +def test_hass_dict_access() -> None: + """Test keys with the same name all access the same value in HassDict.""" + + data = HassDict() + str_key = "custom-key" + key = HassKey[int](str_key) + other_key = HassKey[str]("other-key") + + entry_key = HassEntryKey[int](str_key) + other_entry_key = HassEntryKey[str]("other-key") + + data[str_key] = True + assert data.get(key) is True + assert data.get(other_key) is None + + assert data.get(entry_key) is True # type: ignore[comparison-overlap] + assert data.get(other_entry_key) is None + + data[key] = False + assert data[str_key] is False