Support restoring SensorEntity native_value (#66068)

Mirror of https://github.com/home-assistant/core.git
Commit 009b31941a (parent f8a84f0101)
@@ -52,7 +52,9 @@ from homeassistant.helpers.config_validation import (  # noqa: F401
 )
 from homeassistant.helpers.entity import Entity, EntityDescription
 from homeassistant.helpers.entity_component import EntityComponent
+from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity
 from homeassistant.helpers.typing import ConfigType, StateType
+from homeassistant.util import dt as dt_util

 from .const import CONF_STATE_CLASS  # noqa: F401

@@ -447,3 +449,62 @@ class SensorEntity(Entity):
             return f"<Entity {self.name}>"

         return super().__repr__()
+
+
+@dataclass
+class SensorExtraStoredData(ExtraStoredData):
+    """Object to hold extra stored data."""
+
+    native_value: StateType | date | datetime
+    native_unit_of_measurement: str | None
+
+    def as_dict(self) -> dict[str, Any]:
+        """Return a dict representation of the sensor data."""
+        native_value: StateType | date | datetime | dict[str, str] = self.native_value
+        if isinstance(native_value, (date, datetime)):
+            native_value = {
+                "__type": str(type(native_value)),
+                "isoformat": native_value.isoformat(),
+            }
+        return {
+            "native_value": native_value,
+            "native_unit_of_measurement": self.native_unit_of_measurement,
+        }
+
+    @classmethod
+    def from_dict(cls, restored: dict[str, Any]) -> SensorExtraStoredData | None:
+        """Initialize a stored sensor state from a dict."""
+        try:
+            native_value = restored["native_value"]
+            native_unit_of_measurement = restored["native_unit_of_measurement"]
+        except KeyError:
+            return None
+        try:
+            type_ = native_value["__type"]
+            if type_ == "<class 'datetime.datetime'>":
+                native_value = dt_util.parse_datetime(native_value["isoformat"])
+            elif type_ == "<class 'datetime.date'>":
+                native_value = dt_util.parse_date(native_value["isoformat"])
+        except TypeError:
+            # native_value is not a dict
+            pass
+        except KeyError:
+            # native_value is a dict, but does not have all values
+            return None
+
+        return cls(native_value, native_unit_of_measurement)
+
+
+class RestoreSensor(SensorEntity, RestoreEntity):
+    """Mixin class for restoring previous sensor state."""
+
+    @property
+    def extra_restore_state_data(self) -> SensorExtraStoredData:
+        """Return sensor specific state data to be restored."""
+        return SensorExtraStoredData(self.native_value, self.native_unit_of_measurement)
+
+    async def async_get_last_sensor_data(self) -> SensorExtraStoredData | None:
+        """Restore native_value and native_unit_of_measurement."""
+        if (restored_last_extra_data := await self.async_get_last_extra_data()) is None:
+            return None
+        return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict())
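For context, a platform sensor opts into this by inheriting from the new RestoreSensor mixin and reading the saved data back when it is added to Home Assistant. The sketch below is illustrative only: the entity name and the use of the _attr_* shorthand attributes are assumptions, not part of this change; the MockRestoreSensor added to the test custom component further down shows the same pattern.

from homeassistant.components.sensor import RestoreSensor


class MyRestoringSensor(RestoreSensor):
    """Hypothetical sensor that keeps its value across restarts."""

    async def async_added_to_hass(self) -> None:
        """Reload the last native value and unit saved before shutdown."""
        await super().async_added_to_hass()
        if (last_sensor_data := await self.async_get_last_sensor_data()) is not None:
            # SensorExtraStoredData.from_dict() has already converted a stored
            # {"__type": ..., "isoformat": ...} payload back to date/datetime.
            self._attr_native_value = last_sensor_data.native_value
            self._attr_native_unit_of_measurement = (
                last_sensor_data.native_unit_of_measurement
            )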
@@ -1,6 +1,7 @@
 """Support for restoring entity states on startup."""
 from __future__ import annotations

+from abc import abstractmethod
 import asyncio
 from datetime import datetime, timedelta
 import logging
@@ -34,27 +35,65 @@ STATE_EXPIRATION = timedelta(days=7)
 _StoredStateT = TypeVar("_StoredStateT", bound="StoredState")


+class ExtraStoredData:
+    """Object to hold extra stored data."""
+
+    @abstractmethod
+    def as_dict(self) -> dict[str, Any]:
+        """Return a dict representation of the extra data.
+
+        Must be serializable by Home Assistant's JSONEncoder.
+        """
+
+
+class RestoredExtraData(ExtraStoredData):
+    """Object to hold extra stored data loaded from storage."""
+
+    def __init__(self, json_dict: dict[str, Any]) -> None:
+        """Object to hold extra stored data."""
+        self.json_dict = json_dict
+
+    def as_dict(self) -> dict[str, Any]:
+        """Return a dict representation of the extra data."""
+        return self.json_dict
+
+
 class StoredState:
     """Object to represent a stored state."""

-    def __init__(self, state: State, last_seen: datetime) -> None:
+    def __init__(
+        self,
+        state: State,
+        extra_data: ExtraStoredData | None,
+        last_seen: datetime,
+    ) -> None:
         """Initialize a new stored state."""
-        self.state = state
+        self.extra_data = extra_data
         self.last_seen = last_seen
+        self.state = state

     def as_dict(self) -> dict[str, Any]:
         """Return a dict representation of the stored state."""
-        return {"state": self.state.as_dict(), "last_seen": self.last_seen}
+        result = {
+            "state": self.state.as_dict(),
+            "extra_data": self.extra_data.as_dict() if self.extra_data else None,
+            "last_seen": self.last_seen,
+        }
+        return result

     @classmethod
     def from_dict(cls: type[_StoredStateT], json_dict: dict) -> _StoredStateT:
         """Initialize a stored state from a dict."""
+        extra_data_dict = json_dict.get("extra_data")
+        extra_data = RestoredExtraData(extra_data_dict) if extra_data_dict else None
         last_seen = json_dict["last_seen"]

         if isinstance(last_seen, str):
             last_seen = dt_util.parse_datetime(last_seen)

-        return cls(cast(State, State.from_dict(json_dict["state"])), last_seen)
+        return cls(
+            cast(State, State.from_dict(json_dict["state"])), extra_data, last_seen
+        )


 class RestoreStateData:
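As an aside, the new serialization shape can be exercised directly. This is a minimal sketch, with an invented entity id and payload, showing how extra data rides along with the state through as_dict()/from_dict():

from homeassistant.core import State
from homeassistant.helpers.restore_state import RestoredExtraData, StoredState
from homeassistant.util import dt as dt_util

# Hypothetical payload; RestoredExtraData simply wraps an already JSON-compatible dict.
extra = RestoredExtraData({"native_value": 21.5, "native_unit_of_measurement": "°C"})
stored = StoredState(State("sensor.example", "21.5"), extra, dt_util.utcnow())

# The serialized record now carries "extra_data" next to "state" and "last_seen".
record = stored.as_dict()
assert record["extra_data"] == {
    "native_value": 21.5,
    "native_unit_of_measurement": "°C",
}

# from_dict() rebuilds the extra data as a RestoredExtraData instance (or None).
roundtripped = StoredState.from_dict(record)
assert roundtripped.extra_data is not None
assert roundtripped.extra_data.as_dict() == extra.as_dict()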
@@ -104,7 +143,7 @@ class RestoreStateData:
             hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder
         )
         self.last_states: dict[str, StoredState] = {}
-        self.entity_ids: set[str] = set()
+        self.entities: dict[str, RestoreEntity] = {}

     @callback
     def async_get_stored_states(self) -> list[StoredState]:
@@ -125,9 +164,11 @@ class RestoreStateData:

         # Start with the currently registered states
         stored_states = [
-            StoredState(state, now)
+            StoredState(
+                state, self.entities[state.entity_id].extra_restore_state_data, now
+            )
             for state in all_states
-            if state.entity_id in self.entity_ids and
+            if state.entity_id in self.entities and
             # Ignore all states that are entity registry placeholders
             not state.attributes.get(ATTR_RESTORED)
         ]
@@ -188,12 +229,14 @@ class RestoreStateData:
         )

     @callback
-    def async_restore_entity_added(self, entity_id: str) -> None:
+    def async_restore_entity_added(self, entity: RestoreEntity) -> None:
         """Store this entity's state when hass is shutdown."""
-        self.entity_ids.add(entity_id)
+        self.entities[entity.entity_id] = entity

     @callback
-    def async_restore_entity_removed(self, entity_id: str) -> None:
+    def async_restore_entity_removed(
+        self, entity_id: str, extra_data: ExtraStoredData | None
+    ) -> None:
         """Unregister this entity from saving state."""
         # When an entity is being removed from hass, store its last state. This
         # allows us to support state restoration if the entity is removed, then
@@ -204,9 +247,11 @@ class RestoreStateData:
         if state is not None:
             state = State.from_dict(_encode_complex(state.as_dict()))
         if state is not None:
-            self.last_states[entity_id] = StoredState(state, dt_util.utcnow())
+            self.last_states[entity_id] = StoredState(
+                state, extra_data, dt_util.utcnow()
+            )

-        self.entity_ids.remove(entity_id)
+        self.entities.pop(entity_id)


 def _encode(value: Any) -> Any:
@@ -244,7 +289,7 @@ class RestoreEntity(Entity):
             super().async_internal_added_to_hass(),
             RestoreStateData.async_get_instance(self.hass),
         )
-        data.async_restore_entity_added(self.entity_id)
+        data.async_restore_entity_added(self)

     async def async_internal_will_remove_from_hass(self) -> None:
         """Run when entity will be removed from hass."""
@@ -252,10 +297,10 @@ class RestoreEntity(Entity):
             super().async_internal_will_remove_from_hass(),
             RestoreStateData.async_get_instance(self.hass),
         )
-        data.async_restore_entity_removed(self.entity_id)
+        data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data)

-    async def async_get_last_state(self) -> State | None:
-        """Get the entity state from the previous run."""
+    async def _async_get_restored_data(self) -> StoredState | None:
+        """Get data stored for an entity, if any."""
         if self.hass is None or self.entity_id is None:
             # Return None if this entity isn't added to hass yet
             _LOGGER.warning("Cannot get last state. Entity not added to hass")  # type: ignore[unreachable]
@@ -265,4 +310,24 @@ class RestoreEntity(Entity):
         )
         if self.entity_id not in data.last_states:
             return None
-        return data.last_states[self.entity_id].state
+        return data.last_states[self.entity_id]
+
+    async def async_get_last_state(self) -> State | None:
+        """Get the entity state from the previous run."""
+        if (stored_state := await self._async_get_restored_data()) is None:
+            return None
+        return stored_state.state
+
+    async def async_get_last_extra_data(self) -> ExtraStoredData | None:
+        """Get the entity specific state data from the previous run."""
+        if (stored_state := await self._async_get_restored_data()) is None:
+            return None
+        return stored_state.extra_data
+
+    @property
+    def extra_restore_state_data(self) -> ExtraStoredData | None:
+        """Return entity specific state data to be restored.
+
+        Implemented by platform classes.
+        """
+        return None
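Taken together, these hooks let any RestoreEntity subclass persist structured data beyond the plain state string. A minimal sketch, with an invented entity and counter payload, assuming only the API added above:

from typing import Any

from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity


class CounterExtraStoredData(ExtraStoredData):
    """JSON-serializable snapshot of a counter (illustrative only)."""

    def __init__(self, count: int) -> None:
        self.count = count

    def as_dict(self) -> dict[str, Any]:
        return {"count": self.count}


class RestoringCounter(RestoreEntity):
    """Hypothetical entity that restores an internal counter across restarts."""

    _count = 0

    @property
    def extra_restore_state_data(self) -> CounterExtraStoredData:
        # Called by RestoreStateData when states are dumped to storage.
        return CounterExtraStoredData(self._count)

    async def async_added_to_hass(self) -> None:
        await super().async_added_to_hass()
        # async_get_last_extra_data() returns a RestoredExtraData wrapper or None.
        if (extra := await self.async_get_last_extra_data()) is not None:
            self._count = extra.as_dict().get("count", 0)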
@@ -44,7 +44,7 @@ from homeassistant.const import (
     STATE_OFF,
     STATE_ON,
 )
-from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, State
+from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant
 from homeassistant.helpers import (
     area_registry,
     device_registry,
@@ -937,8 +937,33 @@ def mock_restore_cache(hass, states):
                 json.dumps(restored_state["attributes"], cls=JSONEncoder)
             ),
         }
-        last_states[state.entity_id] = restore_state.StoredState(
-            State.from_dict(restored_state), now
+        last_states[state.entity_id] = restore_state.StoredState.from_dict(
+            {"state": restored_state, "last_seen": now}
+        )
+    data.last_states = last_states
+    _LOGGER.debug("Restore cache: %s", data.last_states)
+    assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
+
+    hass.data[key] = data
+
+
+def mock_restore_cache_with_extra_data(hass, states):
+    """Mock the DATA_RESTORE_CACHE."""
+    key = restore_state.DATA_RESTORE_STATE_TASK
+    data = restore_state.RestoreStateData(hass)
+    now = date_util.utcnow()
+
+    last_states = {}
+    for state, extra_data in states:
+        restored_state = state.as_dict()
+        restored_state = {
+            **restored_state,
+            "attributes": json.loads(
+                json.dumps(restored_state["attributes"], cls=JSONEncoder)
+            ),
+        }
+        last_states[state.entity_id] = restore_state.StoredState.from_dict(
+            {"state": restored_state, "extra_data": extra_data, "last_seen": now}
         )
     data.last_states = last_states
     _LOGGER.debug("Restore cache: %s", data.last_states)
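For reference, the new test helper takes (State, extra_data) pairs rather than bare states. A rough usage sketch, with an invented entity id and payload (the sensor tests below do the same thing for real):

from homeassistant.core import State

from tests.common import mock_restore_cache_with_extra_data


async def test_example_restore(hass):
    """Seed the restore cache with a state plus extra data (illustrative)."""
    mock_restore_cache_with_extra_data(
        hass,
        (
            (
                State("sensor.example", "20.0"),
                {"native_value": 20.0, "native_unit_of_measurement": "°C"},
            ),
        ),
    )
    # ...set up the platform afterwards; the entity's async_get_last_extra_data()
    # will then return the dict above wrapped in RestoredExtraData.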
@@ -11,10 +11,14 @@ from homeassistant.const import (
     TEMP_CELSIUS,
     TEMP_FAHRENHEIT,
 )
+from homeassistant.core import State
+from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
 from homeassistant.setup import async_setup_component
 from homeassistant.util import dt as dt_util
 from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM

+from tests.common import mock_restore_cache_with_extra_data
+

 @pytest.mark.parametrize(
     "unit_system,native_unit,state_unit,native_value,state_value",
@@ -210,3 +214,131 @@ async def test_reject_timezoneless_datetime_str(
         "Invalid datetime: sensor.test provides state '2017-12-19 18:29:42', "
         "which is missing timezone information"
     ) in caplog.text
+
+
+RESTORE_DATA = {
+    "str": {"native_unit_of_measurement": "°F", "native_value": "abc123"},
+    "int": {"native_unit_of_measurement": "°F", "native_value": 123},
+    "float": {"native_unit_of_measurement": "°F", "native_value": 123.0},
+    "date": {
+        "native_unit_of_measurement": "°F",
+        "native_value": {
+            "__type": "<class 'datetime.date'>",
+            "isoformat": date(2020, 2, 8).isoformat(),
+        },
+    },
+    "datetime": {
+        "native_unit_of_measurement": "°F",
+        "native_value": {
+            "__type": "<class 'datetime.datetime'>",
+            "isoformat": datetime(2020, 2, 8, 15, tzinfo=timezone.utc).isoformat(),
+        },
+    },
+}
+
+
+# None | str | int | float | date | datetime:
+@pytest.mark.parametrize(
+    "native_value, native_value_type, expected_extra_data, device_class",
+    [
+        ("abc123", str, RESTORE_DATA["str"], None),
+        (123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE),
+        (123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE),
+        (date(2020, 2, 8), dict, RESTORE_DATA["date"], SensorDeviceClass.DATE),
+        (
+            datetime(2020, 2, 8, 15, tzinfo=timezone.utc),
+            dict,
+            RESTORE_DATA["datetime"],
+            SensorDeviceClass.TIMESTAMP,
+        ),
+    ],
+)
+async def test_restore_sensor_save_state(
+    hass,
+    enable_custom_integrations,
+    hass_storage,
+    native_value,
+    native_value_type,
+    expected_extra_data,
+    device_class,
+):
+    """Test RestoreSensor."""
+    platform = getattr(hass.components, "test.sensor")
+    platform.init(empty=True)
+    platform.ENTITIES["0"] = platform.MockRestoreSensor(
+        name="Test",
+        native_value=native_value,
+        native_unit_of_measurement=TEMP_FAHRENHEIT,
+        device_class=device_class,
+    )
+
+    entity0 = platform.ENTITIES["0"]
+    assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
+    await hass.async_block_till_done()
+
+    # Trigger saving state
+    await hass.async_stop()
+
+    assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
+    state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]
+    assert state["entity_id"] == entity0.entity_id
+    extra_data = hass_storage[RESTORE_STATE_KEY]["data"][0]["extra_data"]
+    assert extra_data == expected_extra_data
+    assert type(extra_data["native_value"]) == native_value_type
+
+
+@pytest.mark.parametrize(
+    "native_value, native_value_type, extra_data, device_class, uom",
+    [
+        ("abc123", str, RESTORE_DATA["str"], None, "°F"),
+        (123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE, "°F"),
+        (123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE, "°F"),
+        (date(2020, 2, 8), date, RESTORE_DATA["date"], SensorDeviceClass.DATE, "°F"),
+        (
+            datetime(2020, 2, 8, 15, tzinfo=timezone.utc),
+            datetime,
+            RESTORE_DATA["datetime"],
+            SensorDeviceClass.TIMESTAMP,
+            "°F",
+        ),
+        (None, type(None), None, None, None),
+        (None, type(None), {}, None, None),
+        (None, type(None), {"beer": 123}, None, None),
+        (
+            None,
+            type(None),
+            {"native_unit_of_measurement": "°F", "native_value": {}},
+            None,
+            None,
+        ),
+    ],
+)
+async def test_restore_sensor_restore_state(
+    hass,
+    enable_custom_integrations,
+    hass_storage,
+    native_value,
+    native_value_type,
+    extra_data,
+    device_class,
+    uom,
+):
+    """Test RestoreSensor."""
+    mock_restore_cache_with_extra_data(hass, ((State("sensor.test", ""), extra_data),))
+
+    platform = getattr(hass.components, "test.sensor")
+    platform.init(empty=True)
+    platform.ENTITIES["0"] = platform.MockRestoreSensor(
+        name="Test",
+        device_class=device_class,
+    )
+
+    entity0 = platform.ENTITIES["0"]
+    assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
+    await hass.async_block_till_done()
+
+    assert hass.states.get(entity0.entity_id)
+
+    assert entity0.native_value == native_value
+    assert type(entity0.native_value) == native_value_type
+    assert entity0.native_unit_of_measurement == uom
@@ -22,9 +22,9 @@ async def test_caching_data(hass):
     """Test that we cache data."""
     now = dt_util.utcnow()
     stored_states = [
-        StoredState(State("input_boolean.b0", "on"), now),
-        StoredState(State("input_boolean.b1", "on"), now),
-        StoredState(State("input_boolean.b2", "on"), now),
+        StoredState(State("input_boolean.b0", "on"), None, now),
+        StoredState(State("input_boolean.b1", "on"), None, now),
+        StoredState(State("input_boolean.b2", "on"), None, now),
     ]

     data = await RestoreStateData.async_get_instance(hass)
@@ -160,9 +160,9 @@ async def test_hass_starting(hass):

     now = dt_util.utcnow()
     stored_states = [
-        StoredState(State("input_boolean.b0", "on"), now),
-        StoredState(State("input_boolean.b1", "on"), now),
-        StoredState(State("input_boolean.b2", "on"), now),
+        StoredState(State("input_boolean.b0", "on"), None, now),
+        StoredState(State("input_boolean.b1", "on"), None, now),
+        StoredState(State("input_boolean.b2", "on"), None, now),
     ]

     data = await RestoreStateData.async_get_instance(hass)
@@ -225,15 +225,16 @@ async def test_dump_data(hass):
     data = await RestoreStateData.async_get_instance(hass)
     now = dt_util.utcnow()
     data.last_states = {
-        "input_boolean.b0": StoredState(State("input_boolean.b0", "off"), now),
-        "input_boolean.b1": StoredState(State("input_boolean.b1", "off"), now),
-        "input_boolean.b2": StoredState(State("input_boolean.b2", "off"), now),
-        "input_boolean.b3": StoredState(State("input_boolean.b3", "off"), now),
+        "input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now),
+        "input_boolean.b1": StoredState(State("input_boolean.b1", "off"), None, now),
+        "input_boolean.b2": StoredState(State("input_boolean.b2", "off"), None, now),
+        "input_boolean.b3": StoredState(State("input_boolean.b3", "off"), None, now),
         "input_boolean.b4": StoredState(
             State("input_boolean.b4", "off"),
+            None,
             datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC),
         ),
-        "input_boolean.b5": StoredState(State("input_boolean.b5", "off"), now),
+        "input_boolean.b5": StoredState(State("input_boolean.b5", "off"), None, now),
     }

     with patch(
@@ -5,6 +5,7 @@ Call init before using it in your tests to ensure clean test data.
 """
 from homeassistant.components.sensor import (
     DEVICE_CLASSES,
+    RestoreSensor,
     SensorDeviceClass,
     SensorEntity,
 )
@@ -109,3 +110,17 @@ class MockSensor(MockEntity, SensorEntity):
     def state_class(self):
         """Return the state class of this sensor."""
         return self._handle("state_class")
+
+
+class MockRestoreSensor(MockSensor, RestoreSensor):
+    """Mock RestoreSensor class."""
+
+    async def async_added_to_hass(self) -> None:
+        """Restore native_value and native_unit_of_measurement."""
+        await super().async_added_to_hass()
+        if (last_sensor_data := await self.async_get_last_sensor_data()) is None:
+            return
+        self._values["native_value"] = last_sensor_data.native_value
+        self._values[
+            "native_unit_of_measurement"
+        ] = last_sensor_data.native_unit_of_measurement