mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Make RestoreStateData.async_get_instance backwards compatible (#93924)
This commit is contained in:
parent
5a8daf06e8
commit
457bc4571d
@ -16,6 +16,7 @@ import homeassistant.util.dt as dt_util
|
|||||||
from . import start
|
from . import start
|
||||||
from .entity import Entity
|
from .entity import Entity
|
||||||
from .event import async_track_time_interval
|
from .event import async_track_time_interval
|
||||||
|
from .frame import report
|
||||||
from .json import JSONEncoder
|
from .json import JSONEncoder
|
||||||
from .storage import Store
|
from .storage import Store
|
||||||
|
|
||||||
@ -96,7 +97,9 @@ class StoredState:
|
|||||||
|
|
||||||
async def async_load(hass: HomeAssistant) -> None:
|
async def async_load(hass: HomeAssistant) -> None:
|
||||||
"""Load the restore state task."""
|
"""Load the restore state task."""
|
||||||
hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass)
|
restore_state = RestoreStateData(hass)
|
||||||
|
await restore_state.async_setup()
|
||||||
|
hass.data[DATA_RESTORE_STATE] = restore_state
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -108,25 +111,26 @@ def async_get(hass: HomeAssistant) -> RestoreStateData:
|
|||||||
class RestoreStateData:
|
class RestoreStateData:
|
||||||
"""Helper class for managing the helper saved data."""
|
"""Helper class for managing the helper saved data."""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def async_get_instance(hass: HomeAssistant) -> RestoreStateData:
|
|
||||||
"""Get the instance of this data helper."""
|
|
||||||
data = RestoreStateData(hass)
|
|
||||||
await data.async_load()
|
|
||||||
|
|
||||||
async def hass_start(hass: HomeAssistant) -> None:
|
|
||||||
"""Start the restore state task."""
|
|
||||||
data.async_setup_dump()
|
|
||||||
|
|
||||||
start.async_at_start(hass, hass_start)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
|
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
|
||||||
"""Dump states now."""
|
"""Dump states now."""
|
||||||
await async_get(hass).async_dump_states()
|
await async_get(hass).async_dump_states()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def async_get_instance(cls, hass: HomeAssistant) -> RestoreStateData:
|
||||||
|
"""Return the instance of this class."""
|
||||||
|
# Nothing should actually be calling this anymore, but we'll keep it
|
||||||
|
# around for a while to avoid breaking custom components.
|
||||||
|
#
|
||||||
|
# In fact they should not be accessing this at all.
|
||||||
|
report(
|
||||||
|
"restore_state.RestoreStateData.async_get_instance is deprecated, "
|
||||||
|
"and not intended to be called by custom components; Please"
|
||||||
|
"refactor your code to use RestoreEntity instead;"
|
||||||
|
" restore_state.async_get(hass) can be used in the meantime",
|
||||||
|
)
|
||||||
|
return async_get(hass)
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the restore state data class."""
|
"""Initialize the restore state data class."""
|
||||||
self.hass: HomeAssistant = hass
|
self.hass: HomeAssistant = hass
|
||||||
@ -136,6 +140,16 @@ class RestoreStateData:
|
|||||||
self.last_states: dict[str, StoredState] = {}
|
self.last_states: dict[str, StoredState] = {}
|
||||||
self.entities: dict[str, RestoreEntity] = {}
|
self.entities: dict[str, RestoreEntity] = {}
|
||||||
|
|
||||||
|
async def async_setup(self) -> None:
|
||||||
|
"""Set up up the instance of this data helper."""
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
async def hass_start(hass: HomeAssistant) -> None:
|
||||||
|
"""Start the restore state task."""
|
||||||
|
self.async_setup_dump()
|
||||||
|
|
||||||
|
start.async_at_start(self.hass, hass_start)
|
||||||
|
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
"""Load the instance of this data helper."""
|
"""Load the instance of this data helper."""
|
||||||
try:
|
try:
|
||||||
|
@ -1,12 +1,19 @@
|
|||||||
"""The tests for the Restore component."""
|
"""The tests for the Restore component."""
|
||||||
|
from collections.abc import Coroutine
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
|
||||||
from homeassistant.core import CoreState, HomeAssistant, State
|
from homeassistant.core import CoreState, HomeAssistant, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.helpers.reload import async_get_platform_without_config_entry
|
||||||
from homeassistant.helpers.restore_state import (
|
from homeassistant.helpers.restore_state import (
|
||||||
DATA_RESTORE_STATE,
|
DATA_RESTORE_STATE,
|
||||||
STORAGE_KEY,
|
STORAGE_KEY,
|
||||||
@ -16,9 +23,20 @@ from homeassistant.helpers.restore_state import (
|
|||||||
async_get,
|
async_get,
|
||||||
async_load,
|
async_load,
|
||||||
)
|
)
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed
|
from tests.common import (
|
||||||
|
MockModule,
|
||||||
|
MockPlatform,
|
||||||
|
async_fire_time_changed,
|
||||||
|
mock_entity_platform,
|
||||||
|
mock_integration,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
DOMAIN = "test_domain"
|
||||||
|
PLATFORM = "test_platform"
|
||||||
|
|
||||||
|
|
||||||
async def test_caching_data(hass: HomeAssistant) -> None:
|
async def test_caching_data(hass: HomeAssistant) -> None:
|
||||||
@ -68,6 +86,20 @@ async def test_caching_data(hass: HomeAssistant) -> None:
|
|||||||
assert mock_write_data.called
|
assert mock_write_data.called
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_instance_backwards_compatibility(hass: HomeAssistant) -> None:
|
||||||
|
"""Test async_get_instance backwards compatibility."""
|
||||||
|
await async_load(hass)
|
||||||
|
data = async_get(hass)
|
||||||
|
# When called from core it should raise
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await RestoreStateData.async_get_instance(hass)
|
||||||
|
|
||||||
|
# When called from a component it should not raise
|
||||||
|
# but it should report
|
||||||
|
with patch("homeassistant.helpers.restore_state.report"):
|
||||||
|
assert data is await RestoreStateData.async_get_instance(hass)
|
||||||
|
|
||||||
|
|
||||||
async def test_periodic_write(hass: HomeAssistant) -> None:
|
async def test_periodic_write(hass: HomeAssistant) -> None:
|
||||||
"""Test that we write periodiclly but not after stop."""
|
"""Test that we write periodiclly but not after stop."""
|
||||||
data = async_get(hass)
|
data = async_get(hass)
|
||||||
@ -401,3 +433,89 @@ async def test_restoring_invalid_entity_id(
|
|||||||
|
|
||||||
state = await entity.async_get_last_state()
|
state = await entity.async_get_last_state()
|
||||||
assert state is None
|
assert state is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_restore_entity_end_to_end(
|
||||||
|
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test restoring an entity end-to-end."""
|
||||||
|
component_setup = Mock(return_value=True)
|
||||||
|
|
||||||
|
setup_called = []
|
||||||
|
|
||||||
|
entity_id = "test_domain.unnamed_device"
|
||||||
|
data = async_get(hass)
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
data.last_states = {
|
||||||
|
entity_id: StoredState(State(entity_id, "stored"), None, now),
|
||||||
|
}
|
||||||
|
|
||||||
|
class MockRestoreEntity(RestoreEntity):
|
||||||
|
"""Mock restore entity."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the mock entity."""
|
||||||
|
self._state: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
"""Return the state."""
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> Coroutine[Any, Any, None]:
|
||||||
|
"""Run when entity about to be added to hass."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
self._state = (await self.async_get_last_state()).state
|
||||||
|
|
||||||
|
async def async_setup_platform(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set up the test platform."""
|
||||||
|
async_add_entities([MockRestoreEntity()])
|
||||||
|
setup_called.append(True)
|
||||||
|
|
||||||
|
mock_integration(hass, MockModule(DOMAIN, setup=component_setup))
|
||||||
|
mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))
|
||||||
|
|
||||||
|
mock_platform = MockPlatform(async_setup_platform=async_setup_platform)
|
||||||
|
mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform)
|
||||||
|
|
||||||
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||||
|
|
||||||
|
await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert component_setup.called
|
||||||
|
|
||||||
|
assert f"{DOMAIN}.{PLATFORM}" in hass.config.components
|
||||||
|
assert len(setup_called) == 1
|
||||||
|
|
||||||
|
platform = async_get_platform_without_config_entry(hass, PLATFORM, DOMAIN)
|
||||||
|
assert platform.platform_name == PLATFORM
|
||||||
|
assert platform.domain == DOMAIN
|
||||||
|
assert hass.states.get(entity_id).state == "stored"
|
||||||
|
|
||||||
|
await data.async_dump_states()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
storage_data = hass_storage[STORAGE_KEY]["data"]
|
||||||
|
assert len(storage_data) == 1
|
||||||
|
assert storage_data[0]["state"]["entity_id"] == entity_id
|
||||||
|
assert storage_data[0]["state"]["state"] == "stored"
|
||||||
|
|
||||||
|
await platform.async_reset()
|
||||||
|
|
||||||
|
assert hass.states.get(entity_id) is None
|
||||||
|
|
||||||
|
# Make sure the entity still gets saved to restore state
|
||||||
|
# even though the platform has been reset since it should
|
||||||
|
# not be expired yet.
|
||||||
|
await data.async_dump_states()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
storage_data = hass_storage[STORAGE_KEY]["data"]
|
||||||
|
assert len(storage_data) == 1
|
||||||
|
assert storage_data[0]["state"]["entity_id"] == entity_id
|
||||||
|
assert storage_data[0]["state"]["state"] == "stored"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user