diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index c75d9c840ed..97069913c80 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -2,7 +2,7 @@ import asyncio from datetime import datetime, timedelta import logging -from typing import Any, Awaitable, Dict, List, Optional, Set, cast +from typing import Any, Dict, List, Optional, Set, cast from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP from homeassistant.core import ( @@ -17,6 +17,7 @@ from homeassistant.helpers import entity_registry from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.singleton import singleton from homeassistant.helpers.storage import Store import homeassistant.util.dt as dt_util @@ -63,45 +64,39 @@ class RestoreStateData: @classmethod async def async_get_instance(cls, hass: HomeAssistant) -> "RestoreStateData": """Get the singleton instance of this data helper.""" - task = hass.data.get(DATA_RESTORE_STATE_TASK) - if task is None: + @singleton(DATA_RESTORE_STATE_TASK) + async def load_instance(hass: HomeAssistant) -> "RestoreStateData": + """Get the singleton instance of this data helper.""" + data = cls(hass) - async def load_instance(hass: HomeAssistant) -> "RestoreStateData": - """Set up the restore state helper.""" - data = cls(hass) + try: + stored_states = await data.store.async_load() + except HomeAssistantError as exc: + _LOGGER.error("Error loading last states", exc_info=exc) + stored_states = None - try: - stored_states = await data.store.async_load() - except HomeAssistantError as exc: - _LOGGER.error("Error loading last states", exc_info=exc) - stored_states = None + if stored_states is None: + _LOGGER.debug("Not creating cache - no saved states found") + data.last_states = {} + else: + data.last_states = { + item["state"]["entity_id"]: StoredState.from_dict(item) + for item in stored_states + if valid_entity_id(item["state"]["entity_id"]) + } + _LOGGER.debug("Created cache with %s", list(data.last_states)) - if stored_states is None: - _LOGGER.debug("Not creating cache - no saved states found") - data.last_states = {} - else: - data.last_states = { - item["state"]["entity_id"]: StoredState.from_dict(item) - for item in stored_states - if valid_entity_id(item["state"]["entity_id"]) - } - _LOGGER.debug("Created cache with %s", list(data.last_states)) + if hass.state == CoreState.running: + data.async_setup_dump() + else: + hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_START, data.async_setup_dump + ) - if hass.state == CoreState.running: - data.async_setup_dump() - else: - hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_START, data.async_setup_dump - ) + return data - return data - - task = hass.data[DATA_RESTORE_STATE_TASK] = hass.async_create_task( - load_instance(hass) - ) - - return await cast(Awaitable["RestoreStateData"], task) + return cast(RestoreStateData, await load_instance(hass)) def __init__(self, hass: HomeAssistant) -> None: """Initialize the restore state data class.""" diff --git a/tests/common.py b/tests/common.py index 16f349de800..e2e183061a7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -818,11 +818,7 @@ def mock_restore_cache(hass, states): _LOGGER.debug("Restore cache: %s", data.last_states) assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" - async def get_restore_state_data() -> restore_state.RestoreStateData: - return data - - # Patch the singleton task in hass.data to return our new RestoreStateData - hass.data[key] = hass.async_create_task(get_restore_state_data()) + hass.data[key] = data class MockEntity(entity.Entity): diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index 7866662266d..15eed1c7e19 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -27,6 +27,7 @@ async def test_caching_data(hass): ] data = await RestoreStateData.async_get_instance(hass) + await hass.async_block_till_done() await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load @@ -41,6 +42,7 @@ async def test_caching_data(hass): "homeassistant.helpers.restore_state.Store.async_save" ) as mock_write_data: state = await entity.async_get_last_state() + await hass.async_block_till_done() assert state is not None assert state.entity_id == "input_boolean.b1" @@ -61,6 +63,7 @@ async def test_hass_starting(hass): ] data = await RestoreStateData.async_get_instance(hass) + await hass.async_block_till_done() await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load @@ -76,6 +79,7 @@ async def test_hass_starting(hass): "homeassistant.helpers.restore_state.Store.async_save" ) as mock_write_data, patch.object(hass.states, "async_all", return_value=states): state = await entity.async_get_last_state() + await hass.async_block_till_done() assert state is not None assert state.entity_id == "input_boolean.b1"