diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index 51f1bd76c2a..cabaf64d859 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -1,7 +1,7 @@ """Support for restoring entity states on startup.""" import asyncio import logging -from datetime import timedelta +from datetime import timedelta, datetime from typing import Any, Dict, List, Set, Optional # noqa pylint_disable=unused-import from homeassistant.core import HomeAssistant, callback, State, CoreState @@ -28,6 +28,32 @@ STATE_DUMP_INTERVAL = timedelta(minutes=15) STATE_EXPIRATION = timedelta(days=7) +class StoredState: + """Object to represent a stored state.""" + + def __init__(self, state: State, last_seen: datetime) -> None: + """Initialize a new stored state.""" + self.state = state + self.last_seen = last_seen + + def as_dict(self) -> Dict: + """Return a dict representation of the stored state.""" + return { + 'state': self.state.as_dict(), + 'last_seen': self.last_seen, + } + + @classmethod + def from_dict(cls, json_dict: Dict) -> 'StoredState': + """Initialize a stored state from a dict.""" + last_seen = json_dict['last_seen'] + + if isinstance(last_seen, str): + last_seen = dt_util.parse_datetime(last_seen) + + return cls(State.from_dict(json_dict['state']), last_seen) + + class RestoreStateData(): """Helper class for managing the helper saved data.""" @@ -43,18 +69,18 @@ class RestoreStateData(): data = cls(hass) try: - states = await data.store.async_load() + stored_states = await data.store.async_load() except HomeAssistantError as exc: _LOGGER.error("Error loading last states", exc_info=exc) - states = None + stored_states = None - if states is None: + if stored_states is None: _LOGGER.debug('Not creating cache - no saved states found') data.last_states = {} else: data.last_states = { - state['entity_id']: State.from_dict(state) - for state in states} + item['state']['entity_id']: StoredState.from_dict(item) + for item in stored_states} _LOGGER.debug( 'Created cache with %s', list(data.last_states)) @@ -74,46 +100,49 @@ class RestoreStateData(): def __init__(self, hass: HomeAssistant) -> None: """Initialize the restore state data class.""" self.hass = hass # type: HomeAssistant - self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY, - encoder=JSONEncoder) # type: Store - self.last_states = {} # type: Dict[str, State] + self.store = Store( + hass, STORAGE_VERSION, STORAGE_KEY, + encoder=JSONEncoder) # type: Store + self.last_states = {} # type: Dict[str, StoredState] self.entity_ids = set() # type: Set[str] - def async_get_states(self) -> List[State]: + def async_get_stored_states(self) -> List[StoredState]: """Get the set of states which should be stored. This includes the states of all registered entities, as well as the stored states from the previous run, which have not been created as entities on this run, and have not expired. """ + now = dt_util.utcnow() all_states = self.hass.states.async_all() current_entity_ids = set(state.entity_id for state in all_states) # Start with the currently registered states - states = [state for state in all_states - if state.entity_id in self.entity_ids] + stored_states = [StoredState(state, now) for state in all_states + if state.entity_id in self.entity_ids] - expiration_time = dt_util.utcnow() - STATE_EXPIRATION + expiration_time = now - STATE_EXPIRATION - for entity_id, state in self.last_states.items(): + for entity_id, stored_state in self.last_states.items(): # Don't save old states that have entities in the current run if entity_id in current_entity_ids: continue # Don't save old states that have expired - if state.last_updated < expiration_time: + if stored_state.last_seen < expiration_time: continue - states.append(state) + stored_states.append(stored_state) - return states + return stored_states async def async_dump_states(self) -> None: """Save the current state machine to storage.""" _LOGGER.debug("Dumping states") try: await self.store.async_save([ - state.as_dict() for state in self.async_get_states()]) + stored_state.as_dict() + for stored_state in self.async_get_stored_states()]) except HomeAssistantError as exc: _LOGGER.error("Error saving current states", exc_info=exc) @@ -172,4 +201,6 @@ class RestoreEntity(Entity): _LOGGER.warning("Cannot get last state. Entity not added to hass") return None data = await RestoreStateData.async_get_instance(self.hass) - return data.last_states.get(self.entity_id) + if self.entity_id not in data.last_states: + return None + return data.last_states[self.entity_id].state diff --git a/tests/common.py b/tests/common.py index 86bc0643d65..db7ce6e3a17 100644 --- a/tests/common.py +++ b/tests/common.py @@ -715,9 +715,11 @@ def mock_restore_cache(hass, states): """Mock the DATA_RESTORE_CACHE.""" key = restore_state.DATA_RESTORE_STATE_TASK data = restore_state.RestoreStateData(hass) + now = date_util.utcnow() data.last_states = { - state.entity_id: state for state in states} + state.entity_id: restore_state.StoredState(state, now) + for state in states} _LOGGER.debug('Restore cache: %s', data.last_states) assert len(data.last_states) == len(states), \ "Duplicate entity_id? {}".format(states) diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index 1ac48264d45..e6693d2cf61 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -6,7 +6,7 @@ from homeassistant.core import CoreState, State from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity import Entity from homeassistant.helpers.restore_state import ( - RestoreStateData, RestoreEntity, DATA_RESTORE_STATE_TASK) + RestoreStateData, RestoreEntity, StoredState, DATA_RESTORE_STATE_TASK) from homeassistant.util import dt as dt_util from asynctest import patch @@ -16,14 +16,15 @@ from tests.common import mock_coro async def test_caching_data(hass): """Test that we cache data.""" - states = [ - State('input_boolean.b0', 'on'), - State('input_boolean.b1', 'on'), - State('input_boolean.b2', 'on'), + now = dt_util.utcnow() + stored_states = [ + StoredState(State('input_boolean.b0', 'on'), now), + StoredState(State('input_boolean.b1', 'on'), now), + StoredState(State('input_boolean.b2', 'on'), now), ] data = await RestoreStateData.async_get_instance(hass) - await data.store.async_save([state.as_dict() for state in states]) + await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load hass.data[DATA_RESTORE_STATE_TASK] = None @@ -48,14 +49,15 @@ async def test_hass_starting(hass): """Test that we cache data.""" hass.state = CoreState.starting - states = [ - State('input_boolean.b0', 'on'), - State('input_boolean.b1', 'on'), - State('input_boolean.b2', 'on'), + now = dt_util.utcnow() + stored_states = [ + StoredState(State('input_boolean.b0', 'on'), now), + StoredState(State('input_boolean.b1', 'on'), now), + StoredState(State('input_boolean.b2', 'on'), now), ] data = await RestoreStateData.async_get_instance(hass) - await data.store.async_save([state.as_dict() for state in states]) + await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load hass.data[DATA_RESTORE_STATE_TASK] = None @@ -109,14 +111,15 @@ async def test_dump_data(hass): await entity.async_added_to_hass() data = await RestoreStateData.async_get_instance(hass) + now = dt_util.utcnow() data.last_states = { - 'input_boolean.b0': State('input_boolean.b0', 'off'), - 'input_boolean.b1': State('input_boolean.b1', 'off'), - 'input_boolean.b2': State('input_boolean.b2', 'off'), - 'input_boolean.b3': State('input_boolean.b3', 'off'), - 'input_boolean.b4': State( - 'input_boolean.b4', 'off', last_updated=datetime( - 1985, 10, 26, 1, 22, tzinfo=dt_util.UTC)), + 'input_boolean.b0': StoredState(State('input_boolean.b0', 'off'), now), + 'input_boolean.b1': StoredState(State('input_boolean.b1', 'off'), now), + 'input_boolean.b2': StoredState(State('input_boolean.b2', 'off'), now), + 'input_boolean.b3': StoredState(State('input_boolean.b3', 'off'), now), + 'input_boolean.b4': StoredState( + State('input_boolean.b4', 'off'), + datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC)), } with patch('homeassistant.helpers.restore_state.Store.async_save' @@ -134,10 +137,10 @@ async def test_dump_data(hass): # b3 should be written, since it is still not expired # b4 should not be written, since it is now expired assert len(written_states) == 2 - assert written_states[0]['entity_id'] == 'input_boolean.b1' - assert written_states[0]['state'] == 'on' - assert written_states[1]['entity_id'] == 'input_boolean.b3' - assert written_states[1]['state'] == 'off' + assert written_states[0]['state']['entity_id'] == 'input_boolean.b1' + assert written_states[0]['state']['state'] == 'on' + assert written_states[1]['state']['entity_id'] == 'input_boolean.b3' + assert written_states[1]['state']['state'] == 'off' # Test that removed entities are not persisted await entity.async_will_remove_from_hass() @@ -151,8 +154,8 @@ async def test_dump_data(hass): args = mock_write_data.mock_calls[0][1] written_states = args[0] assert len(written_states) == 1 - assert written_states[0]['entity_id'] == 'input_boolean.b3' - assert written_states[0]['state'] == 'off' + assert written_states[0]['state']['entity_id'] == 'input_boolean.b3' + assert written_states[0]['state']['state'] == 'off' async def test_dump_error(hass):