diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index da4d2bacf15..f1e74e26908 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -6,16 +6,10 @@ from datetime import datetime, timedelta import logging from typing import Any, cast -from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP -from homeassistant.core import ( - CoreState, - HomeAssistant, - State, - callback, - valid_entity_id, -) +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import HomeAssistant, State, callback, valid_entity_id from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import entity_registry +from homeassistant.helpers import entity_registry, start from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.json import JSONEncoder @@ -63,42 +57,36 @@ class StoredState: class RestoreStateData: """Helper class for managing the helper saved data.""" - @classmethod - async def async_get_instance(cls, hass: HomeAssistant) -> RestoreStateData: + @staticmethod + @singleton(DATA_RESTORE_STATE_TASK) + async def async_get_instance(hass: HomeAssistant) -> RestoreStateData: """Get the singleton instance of this data helper.""" + data = RestoreStateData(hass) - @singleton(DATA_RESTORE_STATE_TASK) - async def load_instance(hass: HomeAssistant) -> RestoreStateData: - """Get the singleton instance of this data helper.""" - data = cls(hass) + try: + stored_states = await data.store.async_load() + except HomeAssistantError as exc: + _LOGGER.error("Error loading last states", exc_info=exc) + stored_states = None - try: - stored_states = await data.store.async_load() - except HomeAssistantError as exc: - _LOGGER.error("Error loading last states", exc_info=exc) - stored_states = None + if stored_states is None: + _LOGGER.debug("Not creating cache - no saved states found") + data.last_states = {} + else: + data.last_states = { + item["state"]["entity_id"]: StoredState.from_dict(item) + for item in stored_states + if valid_entity_id(item["state"]["entity_id"]) + } + _LOGGER.debug("Created cache with %s", list(data.last_states)) - if stored_states is None: - _LOGGER.debug("Not creating cache - no saved states found") - data.last_states = {} - else: - data.last_states = { - item["state"]["entity_id"]: StoredState.from_dict(item) - for item in stored_states - if valid_entity_id(item["state"]["entity_id"]) - } - _LOGGER.debug("Created cache with %s", list(data.last_states)) + async def hass_start(hass: HomeAssistant) -> None: + """Start the restore state task.""" + data.async_setup_dump() - if hass.state == CoreState.running: - data.async_setup_dump() - else: - hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_START, data.async_setup_dump - ) + start.async_at_start(hass, hass_start) - return data - - return cast(RestoreStateData, await load_instance(hass)) + return data @classmethod async def async_save_persistent_states(cls, hass: HomeAssistant) -> None: @@ -269,7 +257,9 @@ class RestoreEntity(Entity): # Return None if this entity isn't added to hass yet _LOGGER.warning("Cannot get last state. Entity not added to hass") # type: ignore[unreachable] return None - data = await RestoreStateData.async_get_instance(self.hass) + data = cast( + RestoreStateData, await RestoreStateData.async_get_instance(self.hass) + ) if self.entity_id not in data.last_states: return None return data.last_states[self.entity_id].state diff --git a/homeassistant/helpers/singleton.py b/homeassistant/helpers/singleton.py index a48ea5d64f0..a3cde0b2f27 100644 --- a/homeassistant/helpers/singleton.py +++ b/homeassistant/helpers/singleton.py @@ -26,31 +26,27 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]: @bind_hass @functools.wraps(func) def wrapped(hass: HomeAssistant) -> T: - obj: T | None = hass.data.get(data_key) - if obj is None: - obj = hass.data[data_key] = func(hass) - return obj + if data_key not in hass.data: + hass.data[data_key] = func(hass) + return cast(T, hass.data[data_key]) return wrapped @bind_hass @functools.wraps(func) async def async_wrapped(hass: HomeAssistant) -> T: - obj_or_evt = hass.data.get(data_key) - - if not obj_or_evt: + if data_key not in hass.data: evt = hass.data[data_key] = asyncio.Event() - result = await func(hass) - hass.data[data_key] = result evt.set() return cast(T, result) + obj_or_evt = hass.data[data_key] + if isinstance(obj_or_evt, asyncio.Event): - evt = obj_or_evt - await evt.wait() - return cast(T, hass.data.get(data_key)) + await obj_or_evt.wait() + return cast(T, hass.data[data_key]) return cast(T, obj_or_evt) diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index d138a5381da..79719b75326 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -32,7 +32,7 @@ async def test_caching_data(hass): await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load - hass.data[DATA_RESTORE_STATE_TASK] = None + hass.data.pop(DATA_RESTORE_STATE_TASK) entity = RestoreEntity() entity.hass = hass @@ -59,7 +59,7 @@ async def test_periodic_write(hass): await data.store.async_save([]) # Emulate a fresh load - hass.data[DATA_RESTORE_STATE_TASK] = None + hass.data.pop(DATA_RESTORE_STATE_TASK) entity = RestoreEntity() entity.hass = hass @@ -105,7 +105,7 @@ async def test_save_persistent_states(hass): await data.store.async_save([]) # Emulate a fresh load - hass.data[DATA_RESTORE_STATE_TASK] = None + hass.data.pop(DATA_RESTORE_STATE_TASK) entity = RestoreEntity() entity.hass = hass @@ -170,7 +170,8 @@ async def test_hass_starting(hass): await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load - hass.data[DATA_RESTORE_STATE_TASK] = None + hass.state = CoreState.not_running + hass.data.pop(DATA_RESTORE_STATE_TASK) entity = RestoreEntity() entity.hass = hass diff --git a/tests/helpers/test_singleton.py b/tests/helpers/test_singleton.py index c695efd94a8..1d4f496a794 100644 --- a/tests/helpers/test_singleton.py +++ b/tests/helpers/test_singleton.py @@ -12,29 +12,33 @@ def mock_hass(): return Mock(data={}) -async def test_singleton_async(mock_hass): +@pytest.mark.parametrize("result", (object(), {}, [])) +async def test_singleton_async(mock_hass, result): """Test singleton with async function.""" @singleton.singleton("test_key") async def something(hass): - return object() + return result result1 = await something(mock_hass) result2 = await something(mock_hass) + assert result1 is result assert result1 is result2 assert "test_key" in mock_hass.data assert mock_hass.data["test_key"] is result1 -def test_singleton(mock_hass): +@pytest.mark.parametrize("result", (object(), {}, [])) +def test_singleton(mock_hass, result): """Test singleton with function.""" @singleton.singleton("test_key") def something(hass): - return object() + return result result1 = something(mock_hass) result2 = something(mock_hass) + assert result1 is result assert result1 is result2 assert "test_key" in mock_hass.data assert mock_hass.data["test_key"] is result1