diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index cf492ab38bd..bdab888842a 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -19,6 +19,7 @@ from .entity import Entity from .event import async_track_time_interval from .frame import report from .json import JSONEncoder +from .singleton import singleton from .storage import Store DATA_RESTORE_STATE: HassKey[RestoreStateData] = HassKey("restore_state") @@ -97,15 +98,14 @@ class StoredState: async def async_load(hass: HomeAssistant) -> None: """Load the restore state task.""" - restore_state = RestoreStateData(hass) - await restore_state.async_setup() - hass.data[DATA_RESTORE_STATE] = restore_state + await async_get(hass).async_setup() @callback +@singleton(DATA_RESTORE_STATE) def async_get(hass: HomeAssistant) -> RestoreStateData: """Get the restore state data helper.""" - return hass.data[DATA_RESTORE_STATE] + return RestoreStateData(hass) class RestoreStateData: diff --git a/tests/common.py b/tests/common.py index 4ed38e22a0b..b25d730a8cd 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1175,6 +1175,7 @@ def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None: _LOGGER.debug("Restore cache: %s", data.last_states) assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" + restore_state.async_get.cache_clear() hass.data[key] = data @@ -1202,6 +1203,7 @@ def mock_restore_cache_with_extra_data( _LOGGER.debug("Restore cache: %s", data.last_states) assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" + restore_state.async_get.cache_clear() hass.data[key] = data