diff --git a/homeassistant/components/number/__init__.py b/homeassistant/components/number/__init__.py index 75f98447865..5a0bf9947f3 100644 --- a/homeassistant/components/number/__init__.py +++ b/homeassistant/components/number/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Callable from contextlib import suppress -from dataclasses import dataclass +import dataclasses from datetime import timedelta import inspect import logging @@ -22,6 +22,7 @@ from homeassistant.helpers.config_validation import ( # noqa: F401 ) from homeassistant.helpers.entity import Entity, EntityDescription from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity from homeassistant.helpers.typing import ConfigType from homeassistant.util import temperature as temperature_util @@ -112,7 +113,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return await component.async_unload_entry(entry) -@dataclass +@dataclasses.dataclass class NumberEntityDescription(EntityDescription): """A class that describes number entities.""" @@ -324,7 +325,7 @@ class NumberEntity(Entity): @property def native_value(self) -> float | None: - """Return the value reported by the sensor.""" + """Return the value reported by the number.""" return self._attr_native_value @property @@ -419,3 +420,53 @@ class NumberEntity(Entity): type(self), report_issue, ) + + +@dataclasses.dataclass +class NumberExtraStoredData(ExtraStoredData): + """Object to hold extra stored data.""" + + native_max_value: float | None + native_min_value: float | None + native_step: float | None + native_unit_of_measurement: str | None + native_value: float | None + + def as_dict(self) -> dict[str, Any]: + """Return a dict representation of the number data.""" + return dataclasses.asdict(self) + + @classmethod + def from_dict(cls, restored: dict[str, Any]) -> NumberExtraStoredData | None: + """Initialize a stored number state from a dict.""" + try: + return cls( + restored["native_max_value"], + restored["native_min_value"], + restored["native_step"], + restored["native_unit_of_measurement"], + restored["native_value"], + ) + except KeyError: + return None + + +class RestoreNumber(NumberEntity, RestoreEntity): + """Mixin class for restoring previous number state.""" + + @property + def extra_restore_state_data(self) -> NumberExtraStoredData: + """Return number specific state data to be restored.""" + return NumberExtraStoredData( + self.native_max_value, + self.native_min_value, + self.native_step, + self.native_unit_of_measurement, + self.native_value, + ) + + async def async_get_last_number_data(self) -> NumberExtraStoredData | None: + """Restore native_*.""" + if (restored_last_extra_data := await self.async_get_last_extra_data()) is None: + return None + return NumberExtraStoredData.from_dict(restored_last_extra_data.as_dict()) diff --git a/tests/components/number/test_init.py b/tests/components/number/test_init.py index ccc6f0da0c5..0df7f79e4a4 100644 --- a/tests/components/number/test_init.py +++ b/tests/components/number/test_init.py @@ -21,10 +21,13 @@ from homeassistant.const import ( TEMP_CELSIUS, TEMP_FAHRENHEIT, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, State +from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY from homeassistant.setup import async_setup_component from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM +from tests.common import mock_restore_cache_with_extra_data + class MockDefaultNumberEntity(NumberEntity): """Mock NumberEntity device to use in tests. @@ -570,3 +573,115 @@ async def test_temperature_conversion( state = hass.states.get(entity0.entity_id) assert float(state.state) == pytest.approx(float(state_max_value), rel=0.1) + + +RESTORE_DATA = { + "native_max_value": 200.0, + "native_min_value": -10.0, + "native_step": 2.0, + "native_unit_of_measurement": "°F", + "native_value": 123.0, +} + + +async def test_restore_number_save_state( + hass, + hass_storage, + enable_custom_integrations, +): + """Test RestoreNumber.""" + platform = getattr(hass.components, "test.number") + platform.init(empty=True) + platform.ENTITIES.append( + platform.MockRestoreNumber( + name="Test", + native_max_value=200.0, + native_min_value=-10.0, + native_step=2.0, + native_unit_of_measurement=TEMP_FAHRENHEIT, + native_value=123.0, + device_class=NumberDeviceClass.TEMPERATURE, + ) + ) + + entity0 = platform.ENTITIES[0] + assert await async_setup_component(hass, "number", {"number": {"platform": "test"}}) + await hass.async_block_till_done() + + # Trigger saving state + await hass.async_stop() + + assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 + state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] + assert state["entity_id"] == entity0.entity_id + extra_data = hass_storage[RESTORE_STATE_KEY]["data"][0]["extra_data"] + assert extra_data == RESTORE_DATA + assert type(extra_data["native_value"]) == float + + +@pytest.mark.parametrize( + "native_max_value, native_min_value, native_step, native_value, native_value_type, extra_data, device_class, uom", + [ + ( + 200.0, + -10.0, + 2.0, + 123.0, + float, + RESTORE_DATA, + NumberDeviceClass.TEMPERATURE, + "°F", + ), + (100.0, 0.0, None, None, type(None), None, None, None), + (100.0, 0.0, None, None, type(None), {}, None, None), + (100.0, 0.0, None, None, type(None), {"beer": 123}, None, None), + ( + 100.0, + 0.0, + None, + None, + type(None), + {"native_unit_of_measurement": "°F", "native_value": {}}, + None, + None, + ), + ], +) +async def test_restore_number_restore_state( + hass, + enable_custom_integrations, + hass_storage, + native_max_value, + native_min_value, + native_step, + native_value, + native_value_type, + extra_data, + device_class, + uom, +): + """Test RestoreNumber.""" + mock_restore_cache_with_extra_data(hass, ((State("number.test", ""), extra_data),)) + + platform = getattr(hass.components, "test.number") + platform.init(empty=True) + platform.ENTITIES.append( + platform.MockRestoreNumber( + device_class=device_class, + name="Test", + native_value=None, + ) + ) + + entity0 = platform.ENTITIES[0] + assert await async_setup_component(hass, "number", {"number": {"platform": "test"}}) + await hass.async_block_till_done() + + assert hass.states.get(entity0.entity_id) + + assert entity0.native_max_value == native_max_value + assert entity0.native_min_value == native_min_value + assert entity0.native_step == native_step + assert entity0.native_value == native_value + assert type(entity0.native_value) == native_value_type + assert entity0.native_unit_of_measurement == uom diff --git a/tests/testing_config/custom_components/test/number.py b/tests/testing_config/custom_components/test/number.py index ac397a4d42b..094698923f4 100644 --- a/tests/testing_config/custom_components/test/number.py +++ b/tests/testing_config/custom_components/test/number.py @@ -3,7 +3,7 @@ Provide a mock number platform. Call init before using it in your tests to ensure clean test data. """ -from homeassistant.components.number import NumberEntity +from homeassistant.components.number import NumberEntity, RestoreNumber from tests.common import MockEntity @@ -37,7 +37,7 @@ class MockNumberEntity(MockEntity, NumberEntity): @property def native_value(self): - """Return the native value of this sensor.""" + """Return the native value of this number.""" return self._handle("native_value") def set_native_value(self, value: float) -> None: @@ -45,6 +45,23 @@ class MockNumberEntity(MockEntity, NumberEntity): self._values["native_value"] = value +class MockRestoreNumber(MockNumberEntity, RestoreNumber): + """Mock RestoreNumber class.""" + + async def async_added_to_hass(self) -> None: + """Restore native_*.""" + await super().async_added_to_hass() + if (last_number_data := await self.async_get_last_number_data()) is None: + return + self._values["native_max_value"] = last_number_data.native_max_value + self._values["native_min_value"] = last_number_data.native_min_value + self._values["native_step"] = last_number_data.native_step + self._values[ + "native_unit_of_measurement" + ] = last_number_data.native_unit_of_measurement + self._values["native_value"] = last_number_data.native_value + + class LegacyMockNumberEntity(MockEntity, NumberEntity): """Mock Number class using deprecated features."""