From fba826ae9ebdbdb6c116ec5a69ed8ae252d153b6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 30 May 2023 20:48:17 -0500 Subject: [PATCH] Migrate restore_state helper to use registry loading pattern (#93773) * Migrate restore_state helper to use registry loading pattern As more entities have started using restore_state over time, it has become a startup bottleneck as each entity being added is creating a task to load restore state data that is already loaded since it is a singleton We now use the same pattern as the registry helpers * fix refactoring error -- guess I am tired * fixes * fix tests * fix more * fix more * fix zha tests * fix zha tests * comments * fix error * add missing coverage * s/DATA_RESTORE_STATE_TASK/DATA_RESTORE_STATE/g --- homeassistant/bootstrap.py | 2 + homeassistant/helpers/restore_state.py | 84 +++++++++++----------- tests/common.py | 35 ++++++++- tests/components/number/test_init.py | 7 +- tests/components/sensor/test_init.py | 7 +- tests/components/text/test_init.py | 7 +- tests/components/timer/test_init.py | 34 +++------ tests/components/zha/test_binary_sensor.py | 5 ++ tests/components/zha/test_select.py | 3 + tests/components/zha/test_sensor.py | 4 ++ tests/helpers/test_restore_state.py | 45 ++++++++---- 11 files changed, 147 insertions(+), 86 deletions(-) diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 67b62da94d5..7e5aa853f12 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -32,6 +32,7 @@ from .helpers import ( entity_registry, issue_registry, recorder, + restore_state, template, ) from .helpers.dispatcher import async_dispatcher_send @@ -248,6 +249,7 @@ async def load_registries(hass: core.HomeAssistant) -> None: issue_registry.async_load(hass), hass.async_add_executor_job(_cache_uname_processor), template.async_load_custom_templates(hass), + restore_state.async_load(hass), ) diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index d31c12d0fd5..e34e3c86324 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -2,7 +2,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -import asyncio from datetime import datetime, timedelta import logging from typing import Any, cast @@ -18,10 +17,9 @@ from . import start from .entity import Entity from .event import async_track_time_interval from .json import JSONEncoder -from .singleton import singleton from .storage import Store -DATA_RESTORE_STATE_TASK = "restore_state_task" +DATA_RESTORE_STATE = "restore_state" _LOGGER = logging.getLogger(__name__) @@ -96,31 +94,25 @@ class StoredState: ) +async def async_load(hass: HomeAssistant) -> None: + """Load the restore state task.""" + hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass) + + +@callback +def async_get(hass: HomeAssistant) -> RestoreStateData: + """Get the restore state data helper.""" + return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE]) + + class RestoreStateData: """Helper class for managing the helper saved data.""" @staticmethod - @singleton(DATA_RESTORE_STATE_TASK) async def async_get_instance(hass: HomeAssistant) -> RestoreStateData: - """Get the singleton instance of this data helper.""" + """Get the instance of this data helper.""" data = RestoreStateData(hass) - - try: - stored_states = await data.store.async_load() - except HomeAssistantError as exc: - _LOGGER.error("Error loading last states", exc_info=exc) - stored_states = None - - if stored_states is None: - _LOGGER.debug("Not creating cache - no saved states found") - data.last_states = {} - else: - data.last_states = { - item["state"]["entity_id"]: StoredState.from_dict(item) - for item in stored_states - if valid_entity_id(item["state"]["entity_id"]) - } - _LOGGER.debug("Created cache with %s", list(data.last_states)) + await data.async_load() async def hass_start(hass: HomeAssistant) -> None: """Start the restore state task.""" @@ -133,8 +125,7 @@ class RestoreStateData: @classmethod async def async_save_persistent_states(cls, hass: HomeAssistant) -> None: """Dump states now.""" - data = await cls.async_get_instance(hass) - await data.async_dump_states() + await async_get(hass).async_dump_states() def __init__(self, hass: HomeAssistant) -> None: """Initialize the restore state data class.""" @@ -145,6 +136,25 @@ class RestoreStateData: self.last_states: dict[str, StoredState] = {} self.entities: dict[str, RestoreEntity] = {} + async def async_load(self) -> None: + """Load the instance of this data helper.""" + try: + stored_states = await self.store.async_load() + except HomeAssistantError as exc: + _LOGGER.error("Error loading last states", exc_info=exc) + stored_states = None + + if stored_states is None: + _LOGGER.debug("Not creating cache - no saved states found") + self.last_states = {} + else: + self.last_states = { + item["state"]["entity_id"]: StoredState.from_dict(item) + for item in stored_states + if valid_entity_id(item["state"]["entity_id"]) + } + _LOGGER.debug("Created cache with %s", list(self.last_states)) + @callback def async_get_stored_states(self) -> list[StoredState]: """Get the set of states which should be stored. @@ -288,21 +298,18 @@ class RestoreEntity(Entity): async def async_internal_added_to_hass(self) -> None: """Register this entity as a restorable entity.""" - _, data = await asyncio.gather( - super().async_internal_added_to_hass(), - RestoreStateData.async_get_instance(self.hass), - ) - data.async_restore_entity_added(self) + await super().async_internal_added_to_hass() + async_get(self.hass).async_restore_entity_added(self) async def async_internal_will_remove_from_hass(self) -> None: """Run when entity will be removed from hass.""" - _, data = await asyncio.gather( - super().async_internal_will_remove_from_hass(), - RestoreStateData.async_get_instance(self.hass), + async_get(self.hass).async_restore_entity_removed( + self.entity_id, self.extra_restore_state_data ) - data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data) + await super().async_internal_will_remove_from_hass() - async def _async_get_restored_data(self) -> StoredState | None: + @callback + def _async_get_restored_data(self) -> StoredState | None: """Get data stored for an entity, if any.""" if self.hass is None or self.entity_id is None: # Return None if this entity isn't added to hass yet @@ -310,20 +317,17 @@ class RestoreEntity(Entity): "Cannot get last state. Entity not added to hass" ) return None - data = await RestoreStateData.async_get_instance(self.hass) - if self.entity_id not in data.last_states: - return None - return data.last_states[self.entity_id] + return async_get(self.hass).last_states.get(self.entity_id) async def async_get_last_state(self) -> State | None: """Get the entity state from the previous run.""" - if (stored_state := await self._async_get_restored_data()) is None: + if (stored_state := self._async_get_restored_data()) is None: return None return stored_state.state async def async_get_last_extra_data(self) -> ExtraStoredData | None: """Get the entity specific state data from the previous run.""" - if (stored_state := await self._async_get_restored_data()) is None: + if (stored_state := self._async_get_restored_data()) is None: return None return stored_state.extra_data diff --git a/tests/common.py b/tests/common.py index f0d7a8de3c3..ca164fcaaf8 100644 --- a/tests/common.py +++ b/tests/common.py @@ -61,6 +61,7 @@ from homeassistant.helpers import ( issue_registry as ir, recorder as recorder_helper, restore_state, + restore_state as rs, storage, ) from homeassistant.helpers.dispatcher import async_dispatcher_connect @@ -251,12 +252,20 @@ async def async_test_home_assistant(event_loop, load_registries=True): # Load the registries entity.async_setup(hass) if load_registries: - with patch("homeassistant.helpers.storage.Store.async_load", return_value=None): + with patch( + "homeassistant.helpers.storage.Store.async_load", return_value=None + ), patch( + "homeassistant.helpers.restore_state.RestoreStateData.async_setup_dump", + return_value=None, + ), patch( + "homeassistant.helpers.restore_state.start.async_at_start" + ): await asyncio.gather( ar.async_load(hass), dr.async_load(hass), er.async_load(hass), ir.async_load(hass), + rs.async_load(hass), ) hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None @@ -1010,7 +1019,7 @@ def init_recorder_component(hass, add_config=None, db_url="sqlite://"): def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None: """Mock the DATA_RESTORE_CACHE.""" - key = restore_state.DATA_RESTORE_STATE_TASK + key = restore_state.DATA_RESTORE_STATE data = restore_state.RestoreStateData(hass) now = dt_util.utcnow() @@ -1037,7 +1046,7 @@ def mock_restore_cache_with_extra_data( hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]] ) -> None: """Mock the DATA_RESTORE_CACHE.""" - key = restore_state.DATA_RESTORE_STATE_TASK + key = restore_state.DATA_RESTORE_STATE data = restore_state.RestoreStateData(hass) now = dt_util.utcnow() @@ -1060,6 +1069,26 @@ def mock_restore_cache_with_extra_data( hass.data[key] = data +async def async_mock_restore_state_shutdown_restart( + hass: HomeAssistant, +) -> restore_state.RestoreStateData: + """Mock shutting down and saving restore state and restoring.""" + data = restore_state.async_get(hass) + await data.async_dump_states() + await async_mock_load_restore_state_from_storage(hass) + return data + + +async def async_mock_load_restore_state_from_storage( + hass: HomeAssistant, +) -> None: + """Mock loading restore state from storage. + + hass_storage must already be mocked. + """ + await restore_state.async_get(hass).async_load() + + class MockEntity(entity.Entity): """Mock Entity class.""" diff --git a/tests/components/number/test_init.py b/tests/components/number/test_init.py index 67a02968037..6cd9a53b6f4 100644 --- a/tests/components/number/test_init.py +++ b/tests/components/number/test_init.py @@ -34,7 +34,10 @@ from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY from homeassistant.setup import async_setup_component from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM -from tests.common import mock_restore_cache_with_extra_data +from tests.common import ( + async_mock_restore_state_shutdown_restart, + mock_restore_cache_with_extra_data, +) class MockDefaultNumberEntity(NumberEntity): @@ -635,7 +638,7 @@ async def test_restore_number_save_state( await hass.async_block_till_done() # Trigger saving state - await hass.async_stop() + await async_mock_restore_state_shutdown_restart(hass) assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] diff --git a/tests/components/sensor/test_init.py b/tests/components/sensor/test_init.py index adcbb8084b7..fb079b9ff55 100644 --- a/tests/components/sensor/test_init.py +++ b/tests/components/sensor/test_init.py @@ -35,7 +35,10 @@ from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM -from tests.common import mock_restore_cache_with_extra_data +from tests.common import ( + async_mock_restore_state_shutdown_restart, + mock_restore_cache_with_extra_data, +) @pytest.mark.parametrize( @@ -397,7 +400,7 @@ async def test_restore_sensor_save_state( await hass.async_block_till_done() # Trigger saving state - await hass.async_stop() + await async_mock_restore_state_shutdown_restart(hass) assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] diff --git a/tests/components/text/test_init.py b/tests/components/text/test_init.py index 666ffe18774..d144cc86c91 100644 --- a/tests/components/text/test_init.py +++ b/tests/components/text/test_init.py @@ -20,7 +20,10 @@ from homeassistant.core import HomeAssistant, ServiceCall, State from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY from homeassistant.setup import async_setup_component -from tests.common import mock_restore_cache_with_extra_data +from tests.common import ( + async_mock_restore_state_shutdown_restart, + mock_restore_cache_with_extra_data, +) class MockTextEntity(TextEntity): @@ -141,7 +144,7 @@ async def test_restore_number_save_state( await hass.async_block_till_done() # Trigger saving state - await hass.async_stop() + await async_mock_restore_state_shutdown_restart(hass) assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] diff --git a/tests/components/timer/test_init.py b/tests/components/timer/test_init.py index a60e42eb768..ae700ce08bf 100644 --- a/tests/components/timer/test_init.py +++ b/tests/components/timer/test_init.py @@ -47,11 +47,7 @@ from homeassistant.const import ( from homeassistant.core import Context, CoreState, HomeAssistant, State from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.helpers import config_validation as cv, entity_registry as er -from homeassistant.helpers.restore_state import ( - DATA_RESTORE_STATE_TASK, - RestoreStateData, - StoredState, -) +from homeassistant.helpers.restore_state import StoredState, async_get from homeassistant.setup import async_setup_component from homeassistant.util.dt import utcnow @@ -838,12 +834,9 @@ async def test_restore_idle(hass: HomeAssistant) -> None: utc_now, ) - data = await RestoreStateData.async_get_instance(hass) - await hass.async_block_till_done() + data = async_get(hass) await data.store.async_save([stored_state.as_dict()]) - - # Emulate a fresh load - hass.data.pop(DATA_RESTORE_STATE_TASK) + await data.async_load() entity = Timer.from_storage( { @@ -878,12 +871,9 @@ async def test_restore_paused(hass: HomeAssistant) -> None: utc_now, ) - data = await RestoreStateData.async_get_instance(hass) - await hass.async_block_till_done() + data = async_get(hass) await data.store.async_save([stored_state.as_dict()]) - - # Emulate a fresh load - hass.data.pop(DATA_RESTORE_STATE_TASK) + await data.async_load() entity = Timer.from_storage( { @@ -922,12 +912,9 @@ async def test_restore_active_resume(hass: HomeAssistant) -> None: utc_now, ) - data = await RestoreStateData.async_get_instance(hass) - await hass.async_block_till_done() + data = async_get(hass) await data.store.async_save([stored_state.as_dict()]) - - # Emulate a fresh load - hass.data.pop(DATA_RESTORE_STATE_TASK) + await data.async_load() entity = Timer.from_storage( { @@ -973,12 +960,9 @@ async def test_restore_active_finished_outside_grace(hass: HomeAssistant) -> Non utc_now, ) - data = await RestoreStateData.async_get_instance(hass) - await hass.async_block_till_done() + data = async_get(hass) await data.store.async_save([stored_state.as_dict()]) - - # Emulate a fresh load - hass.data.pop(DATA_RESTORE_STATE_TASK) + await data.async_load() entity = Timer.from_storage( { diff --git a/tests/components/zha/test_binary_sensor.py b/tests/components/zha/test_binary_sensor.py index 2c0461a3c7c..2a30e053376 100644 --- a/tests/components/zha/test_binary_sensor.py +++ b/tests/components/zha/test_binary_sensor.py @@ -21,6 +21,8 @@ from .common import ( ) from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE +from tests.common import async_mock_load_restore_state_from_storage + DEVICE_IAS = { 1: { SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, @@ -186,6 +188,7 @@ async def test_binary_sensor_migration_not_migrated( entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone" core_rs(entity_id, state=restored_state, attributes={}) # migration sensor state + await async_mock_load_restore_state_from_storage(hass) zigpy_device = zigpy_device_mock(DEVICE_IAS) zha_device = await zha_device_restored(zigpy_device) @@ -208,6 +211,7 @@ async def test_binary_sensor_migration_already_migrated( entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone" core_rs(entity_id, state=STATE_OFF, attributes={"migrated_to_cache": True}) + await async_mock_load_restore_state_from_storage(hass) zigpy_device = zigpy_device_mock(DEVICE_IAS) @@ -243,6 +247,7 @@ async def test_onoff_binary_sensor_restore_state( entity_id = "binary_sensor.fakemanufacturer_fakemodel_opening" core_rs(entity_id, state=restored_state, attributes={}) + await async_mock_load_restore_state_from_storage(hass) zigpy_device = zigpy_device_mock(DEVICE_ONOFF) zha_device = await zha_device_restored(zigpy_device) diff --git a/tests/components/zha/test_select.py b/tests/components/zha/test_select.py index 714e27147bb..fb1930e3f99 100644 --- a/tests/components/zha/test_select.py +++ b/tests/components/zha/test_select.py @@ -26,6 +26,8 @@ from homeassistant.util import dt as dt_util from .common import async_enable_traffic, find_entity_id, send_attributes_report from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE +from tests.common import async_mock_load_restore_state_from_storage + @pytest.fixture(autouse=True) def select_select_only(): @@ -176,6 +178,7 @@ async def test_select_restore_state( entity_id = "select.fakemanufacturer_fakemodel_default_siren_tone" core_rs(entity_id, state="Burglar") + await async_mock_load_restore_state_from_storage(hass) zigpy_device = zigpy_device_mock( { diff --git a/tests/components/zha/test_sensor.py b/tests/components/zha/test_sensor.py index 83799147bbe..7d821ced4a0 100644 --- a/tests/components/zha/test_sensor.py +++ b/tests/components/zha/test_sensor.py @@ -47,6 +47,8 @@ from .common import ( ) from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE +from tests.common import async_mock_load_restore_state_from_storage + ENTITY_ID_PREFIX = "sensor.fakemanufacturer_fakemodel_{}" @@ -530,6 +532,7 @@ def core_rs(hass_storage): ], ) async def test_temp_uom( + hass: HomeAssistant, uom, raw_temp, expected, @@ -544,6 +547,7 @@ async def test_temp_uom( entity_id = "sensor.fake1026_fakemodel1026_004f3202_temperature" if restore: core_rs(entity_id, uom, state=(expected - 2)) + await async_mock_load_restore_state_from_storage(hass) hass = await hass_ms( CONF_UNIT_SYSTEM_METRIC diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index 1e8fa8b7fb4..cf6a078d137 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -8,11 +8,13 @@ from homeassistant.core import CoreState, HomeAssistant, State from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity import Entity from homeassistant.helpers.restore_state import ( - DATA_RESTORE_STATE_TASK, + DATA_RESTORE_STATE, STORAGE_KEY, RestoreEntity, RestoreStateData, StoredState, + async_get, + async_load, ) from homeassistant.util import dt as dt_util @@ -28,12 +30,25 @@ async def test_caching_data(hass: HomeAssistant) -> None: StoredState(State("input_boolean.b2", "on"), None, now), ] - data = await RestoreStateData.async_get_instance(hass) + data = async_get(hass) await hass.async_block_till_done() await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load - hass.data.pop(DATA_RESTORE_STATE_TASK) + hass.data.pop(DATA_RESTORE_STATE) + + with patch( + "homeassistant.helpers.restore_state.Store.async_load", + side_effect=HomeAssistantError, + ): + # Failure to load should not be treated as fatal + await async_load(hass) + + data = async_get(hass) + assert data.last_states == {} + + await async_load(hass) + data = async_get(hass) entity = RestoreEntity() entity.hass = hass @@ -55,12 +70,14 @@ async def test_caching_data(hass: HomeAssistant) -> None: async def test_periodic_write(hass: HomeAssistant) -> None: """Test that we write periodiclly but not after stop.""" - data = await RestoreStateData.async_get_instance(hass) + data = async_get(hass) await hass.async_block_till_done() await data.store.async_save([]) # Emulate a fresh load - hass.data.pop(DATA_RESTORE_STATE_TASK) + hass.data.pop(DATA_RESTORE_STATE) + await async_load(hass) + data = async_get(hass) entity = RestoreEntity() entity.hass = hass @@ -101,12 +118,14 @@ async def test_periodic_write(hass: HomeAssistant) -> None: async def test_save_persistent_states(hass: HomeAssistant) -> None: """Test that we cancel the currently running job, save the data, and verify the perdiodic job continues.""" - data = await RestoreStateData.async_get_instance(hass) + data = async_get(hass) await hass.async_block_till_done() await data.store.async_save([]) # Emulate a fresh load - hass.data.pop(DATA_RESTORE_STATE_TASK) + hass.data.pop(DATA_RESTORE_STATE) + await async_load(hass) + data = async_get(hass) entity = RestoreEntity() entity.hass = hass @@ -166,13 +185,15 @@ async def test_hass_starting(hass: HomeAssistant) -> None: StoredState(State("input_boolean.b2", "on"), None, now), ] - data = await RestoreStateData.async_get_instance(hass) + data = async_get(hass) await hass.async_block_till_done() await data.store.async_save([state.as_dict() for state in stored_states]) # Emulate a fresh load hass.state = CoreState.not_running - hass.data.pop(DATA_RESTORE_STATE_TASK) + hass.data.pop(DATA_RESTORE_STATE) + await async_load(hass) + data = async_get(hass) entity = RestoreEntity() entity.hass = hass @@ -223,7 +244,7 @@ async def test_dump_data(hass: HomeAssistant) -> None: entity.entity_id = "input_boolean.b1" await entity.async_internal_added_to_hass() - data = await RestoreStateData.async_get_instance(hass) + data = async_get(hass) now = dt_util.utcnow() data.last_states = { "input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now), @@ -297,7 +318,7 @@ async def test_dump_error(hass: HomeAssistant) -> None: entity.entity_id = "input_boolean.b1" await entity.async_internal_added_to_hass() - data = await RestoreStateData.async_get_instance(hass) + data = async_get(hass) with patch( "homeassistant.helpers.restore_state.Store.async_save", @@ -335,7 +356,7 @@ async def test_state_saved_on_remove(hass: HomeAssistant) -> None: "input_boolean.b0", "on", {"complicated": {"value": {1, 2, now}}} ) - data = await RestoreStateData.async_get_instance(hass) + data = async_get(hass) # No last states should currently be saved assert not data.last_states