diff --git a/.coveragerc b/.coveragerc index 98992bed247..5c5d88b42c4 100644 --- a/.coveragerc +++ b/.coveragerc @@ -331,6 +331,7 @@ omit = homeassistant/components/esphome/camera.py homeassistant/components/esphome/climate.py homeassistant/components/esphome/cover.py + homeassistant/components/esphome/domain_data.py homeassistant/components/esphome/entry_data.py homeassistant/components/esphome/fan.py homeassistant/components/esphome/light.py diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index 07b6d3071f6..acf8d33b6e0 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Callable -from dataclasses import dataclass, field import functools import logging import math @@ -47,70 +46,18 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import async_track_state_change_event -from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.service import async_set_service_schema -from homeassistant.helpers.storage import Store from homeassistant.helpers.template import Template from .bluetooth import async_connect_scanner +from .domain_data import DOMAIN, DomainData # Import config flow so that it's added to the registry from .entry_data import RuntimeEntryData -DOMAIN = "esphome" CONF_NOISE_PSK = "noise_psk" _LOGGER = logging.getLogger(__name__) _R = TypeVar("_R") -_DomainDataSelfT = TypeVar("_DomainDataSelfT", bound="DomainData") - -STORAGE_VERSION = 1 - - -@dataclass -class DomainData: - """Define a class that stores global esphome data in hass.data[DOMAIN].""" - - _entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict) - _stores: dict[str, Store] = field(default_factory=dict) - - def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData: - """Return the runtime entry data associated with this config entry. - - Raises KeyError if the entry isn't loaded yet. - """ - return self._entry_datas[entry.entry_id] - - def set_entry_data(self, entry: ConfigEntry, entry_data: RuntimeEntryData) -> None: - """Set the runtime entry data associated with this config entry.""" - if entry.entry_id in self._entry_datas: - raise ValueError("Entry data for this entry is already set") - self._entry_datas[entry.entry_id] = entry_data - - def pop_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData: - """Pop the runtime entry data instance associated with this config entry.""" - return self._entry_datas.pop(entry.entry_id) - - def is_entry_loaded(self, entry: ConfigEntry) -> bool: - """Check whether the given entry is loaded.""" - return entry.entry_id in self._entry_datas - - def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store: - """Get or create a Store instance for the given config entry.""" - return self._stores.setdefault( - entry.entry_id, - Store( - hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder - ), - ) - - @classmethod - def get(cls: type[_DomainDataSelfT], hass: HomeAssistant) -> _DomainDataSelfT: - """Get the global DomainData instance stored in hass.data.""" - # Don't use setdefault - this is a hot code path - if DOMAIN in hass.data: - return cast(_DomainDataSelfT, hass.data[DOMAIN]) - ret = hass.data[DOMAIN] = cls() - return ret async def async_setup_entry( # noqa: C901 diff --git a/homeassistant/components/esphome/domain_data.py b/homeassistant/components/esphome/domain_data.py new file mode 100644 index 00000000000..9fabcf17d78 --- /dev/null +++ b/homeassistant/components/esphome/domain_data.py @@ -0,0 +1,68 @@ +"""Support for esphome domain data.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TypeVar, cast + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.storage import Store + +from .entry_data import RuntimeEntryData + +STORAGE_VERSION = 1 +DOMAIN = "esphome" +_DomainDataSelfT = TypeVar("_DomainDataSelfT", bound="DomainData") + + +@dataclass +class DomainData: + """Define a class that stores global esphome data in hass.data[DOMAIN].""" + + _entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict) + _stores: dict[str, Store] = field(default_factory=dict) + _entry_by_unique_id: dict[str, ConfigEntry] = field(default_factory=dict) + + def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData: + """Return the runtime entry data associated with this config entry. + + Raises KeyError if the entry isn't loaded yet. + """ + return self._entry_datas[entry.entry_id] + + def set_entry_data(self, entry: ConfigEntry, entry_data: RuntimeEntryData) -> None: + """Set the runtime entry data associated with this config entry.""" + if entry.entry_id in self._entry_datas: + raise ValueError("Entry data for this entry is already set") + self._entry_datas[entry.entry_id] = entry_data + if entry.unique_id: + self._entry_by_unique_id[entry.unique_id] = entry + + def pop_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData: + """Pop the runtime entry data instance associated with this config entry.""" + if entry.unique_id: + del self._entry_by_unique_id[entry.unique_id] + return self._entry_datas.pop(entry.entry_id) + + def is_entry_loaded(self, entry: ConfigEntry) -> bool: + """Check whether the given entry is loaded.""" + return entry.entry_id in self._entry_datas + + def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store: + """Get or create a Store instance for the given config entry.""" + return self._stores.setdefault( + entry.entry_id, + Store( + hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder + ), + ) + + @classmethod + def get(cls: type[_DomainDataSelfT], hass: HomeAssistant) -> _DomainDataSelfT: + """Get the global DomainData instance stored in hass.data.""" + # Don't use setdefault - this is a hot code path + if DOMAIN in hass.data: + return cast(_DomainDataSelfT, hass.data[DOMAIN]) + ret = hass.data[DOMAIN] = cls() + return ret