diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index e34e3c86324..ab3b93cf3c4 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -16,6 +16,7 @@ import homeassistant.util.dt as dt_util from . import start from .entity import Entity from .event import async_track_time_interval +from .frame import report from .json import JSONEncoder from .storage import Store @@ -96,7 +97,9 @@ class StoredState: async def async_load(hass: HomeAssistant) -> None: """Load the restore state task.""" - hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass) + restore_state = RestoreStateData(hass) + await restore_state.async_setup() + hass.data[DATA_RESTORE_STATE] = restore_state @callback @@ -108,25 +111,26 @@ def async_get(hass: HomeAssistant) -> RestoreStateData: class RestoreStateData: """Helper class for managing the helper saved data.""" - @staticmethod - async def async_get_instance(hass: HomeAssistant) -> RestoreStateData: - """Get the instance of this data helper.""" - data = RestoreStateData(hass) - await data.async_load() - - async def hass_start(hass: HomeAssistant) -> None: - """Start the restore state task.""" - data.async_setup_dump() - - start.async_at_start(hass, hass_start) - - return data - @classmethod async def async_save_persistent_states(cls, hass: HomeAssistant) -> None: """Dump states now.""" await async_get(hass).async_dump_states() + @classmethod + async def async_get_instance(cls, hass: HomeAssistant) -> RestoreStateData: + """Return the instance of this class.""" + # Nothing should actually be calling this anymore, but we'll keep it + # around for a while to avoid breaking custom components. + # + # In fact they should not be accessing this at all. + report( + "restore_state.RestoreStateData.async_get_instance is deprecated, " + "and not intended to be called by custom components; Please" + "refactor your code to use RestoreEntity instead;" + " restore_state.async_get(hass) can be used in the meantime", + ) + return async_get(hass) + def __init__(self, hass: HomeAssistant) -> None: """Initialize the restore state data class.""" self.hass: HomeAssistant = hass @@ -136,6 +140,16 @@ class RestoreStateData: self.last_states: dict[str, StoredState] = {} self.entities: dict[str, RestoreEntity] = {} + async def async_setup(self) -> None: + """Set up up the instance of this data helper.""" + await self.async_load() + + async def hass_start(hass: HomeAssistant) -> None: + """Start the restore state task.""" + self.async_setup_dump() + + start.async_at_start(self.hass, hass_start) + async def async_load(self) -> None: """Load the instance of this data helper.""" try: diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index cf6a078d137..b5ce7afade0 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -1,12 +1,19 @@ """The tests for the Restore component.""" +from collections.abc import Coroutine from datetime import datetime, timedelta +import logging from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch + +import pytest from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP from homeassistant.core import CoreState, HomeAssistant, State from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity import Entity +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.reload import async_get_platform_without_config_entry from homeassistant.helpers.restore_state import ( DATA_RESTORE_STATE, STORAGE_KEY, @@ -16,9 +23,20 @@ from homeassistant.helpers.restore_state import ( async_get, async_load, ) +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.util import dt as dt_util -from tests.common import async_fire_time_changed +from tests.common import ( + MockModule, + MockPlatform, + async_fire_time_changed, + mock_entity_platform, + mock_integration, +) + +_LOGGER = logging.getLogger(__name__) +DOMAIN = "test_domain" +PLATFORM = "test_platform" async def test_caching_data(hass: HomeAssistant) -> None: @@ -68,6 +86,20 @@ async def test_caching_data(hass: HomeAssistant) -> None: assert mock_write_data.called +async def test_async_get_instance_backwards_compatibility(hass: HomeAssistant) -> None: + """Test async_get_instance backwards compatibility.""" + await async_load(hass) + data = async_get(hass) + # When called from core it should raise + with pytest.raises(RuntimeError): + await RestoreStateData.async_get_instance(hass) + + # When called from a component it should not raise + # but it should report + with patch("homeassistant.helpers.restore_state.report"): + assert data is await RestoreStateData.async_get_instance(hass) + + async def test_periodic_write(hass: HomeAssistant) -> None: """Test that we write periodiclly but not after stop.""" data = async_get(hass) @@ -401,3 +433,89 @@ async def test_restoring_invalid_entity_id( state = await entity.async_get_last_state() assert state is None + + +async def test_restore_entity_end_to_end( + hass: HomeAssistant, hass_storage: dict[str, Any] +) -> None: + """Test restoring an entity end-to-end.""" + component_setup = Mock(return_value=True) + + setup_called = [] + + entity_id = "test_domain.unnamed_device" + data = async_get(hass) + now = dt_util.utcnow() + data.last_states = { + entity_id: StoredState(State(entity_id, "stored"), None, now), + } + + class MockRestoreEntity(RestoreEntity): + """Mock restore entity.""" + + def __init__(self): + """Initialize the mock entity.""" + self._state: str | None = None + + @property + def state(self): + """Return the state.""" + return self._state + + async def async_added_to_hass(self) -> Coroutine[Any, Any, None]: + """Run when entity about to be added to hass.""" + await super().async_added_to_hass() + self._state = (await self.async_get_last_state()).state + + async def async_setup_platform( + hass: HomeAssistant, + config: ConfigType, + async_add_entities: AddEntitiesCallback, + discovery_info: DiscoveryInfoType | None = None, + ) -> None: + """Set up the test platform.""" + async_add_entities([MockRestoreEntity()]) + setup_called.append(True) + + mock_integration(hass, MockModule(DOMAIN, setup=component_setup)) + mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN])) + + mock_platform = MockPlatform(async_setup_platform=async_setup_platform) + mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + + await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}}) + await hass.async_block_till_done() + assert component_setup.called + + assert f"{DOMAIN}.{PLATFORM}" in hass.config.components + assert len(setup_called) == 1 + + platform = async_get_platform_without_config_entry(hass, PLATFORM, DOMAIN) + assert platform.platform_name == PLATFORM + assert platform.domain == DOMAIN + assert hass.states.get(entity_id).state == "stored" + + await data.async_dump_states() + await hass.async_block_till_done() + + storage_data = hass_storage[STORAGE_KEY]["data"] + assert len(storage_data) == 1 + assert storage_data[0]["state"]["entity_id"] == entity_id + assert storage_data[0]["state"]["state"] == "stored" + + await platform.async_reset() + + assert hass.states.get(entity_id) is None + + # Make sure the entity still gets saved to restore state + # even though the platform has been reset since it should + # not be expired yet. + await data.async_dump_states() + await hass.async_block_till_done() + + storage_data = hass_storage[STORAGE_KEY]["data"] + assert len(storage_data) == 1 + assert storage_data[0]["state"]["entity_id"] == entity_id + assert storage_data[0]["state"]["state"] == "stored"