Restore state helper to work with entity registry restoration (#30451)

* Restore state helper to work with entity registry restoratino

* Update restore_state.py
This commit is contained in:
Paulus Schoutsen 2020-01-05 11:58:59 +01:00 committed by GitHub
parent 2ac5862eda
commit 24b25b8917
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 5 deletions

View File

@ -13,6 +13,7 @@ from homeassistant.core import (
valid_entity_id, valid_entity_id,
) )
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
@ -123,19 +124,27 @@ class RestoreStateData:
""" """
now = dt_util.utcnow() now = dt_util.utcnow()
all_states = self.hass.states.async_all() all_states = self.hass.states.async_all()
current_entity_ids = set(state.entity_id for state in all_states) # Entities currently backed by an entity object
current_entity_ids = set(
state.entity_id
for state in all_states
if not state.attributes.get(entity_registry.ATTR_RESTORED)
)
# Start with the currently registered states # Start with the currently registered states
stored_states = [ stored_states = [
StoredState(state, now) StoredState(state, now)
for state in all_states for state in all_states
if state.entity_id in self.entity_ids if state.entity_id in self.entity_ids and
# Ignore all states that are entity registry placeholders
not state.attributes.get(entity_registry.ATTR_RESTORED)
] ]
expiration_time = now - STATE_EXPIRATION expiration_time = now - STATE_EXPIRATION
for entity_id, stored_state in self.last_states.items(): for entity_id, stored_state in self.last_states.items():
# Don't save old states that have entities in the current run # Don't save old states that have entities in the current run
# They are either registered and already part of stored_states,
# or no longer care about restoring.
if entity_id in current_entity_ids: if entity_id in current_entity_ids:
continue continue

View File

@ -103,6 +103,7 @@ async def test_dump_data(hass):
State("input_boolean.b0", "on"), State("input_boolean.b0", "on"),
State("input_boolean.b1", "on"), State("input_boolean.b1", "on"),
State("input_boolean.b2", "on"), State("input_boolean.b2", "on"),
State("input_boolean.b5", "unavailable", {"restored": True}),
] ]
entity = Entity() entity = Entity()
@ -126,6 +127,7 @@ async def test_dump_data(hass):
State("input_boolean.b4", "off"), State("input_boolean.b4", "off"),
datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC), datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC),
), ),
"input_boolean.b5": StoredState(State("input_boolean.b5", "off"), now),
} }
with patch( with patch(
@ -142,11 +144,14 @@ async def test_dump_data(hass):
# b2 should not be written, since it is not registered with the helper # b2 should not be written, since it is not registered with the helper
# b3 should be written, since it is still not expired # b3 should be written, since it is still not expired
# b4 should not be written, since it is now expired # b4 should not be written, since it is now expired
assert len(written_states) == 2 # b5 should be written, since current state is restored by entity registry
assert len(written_states) == 3
assert written_states[0]["state"]["entity_id"] == "input_boolean.b1" assert written_states[0]["state"]["entity_id"] == "input_boolean.b1"
assert written_states[0]["state"]["state"] == "on" assert written_states[0]["state"]["state"] == "on"
assert written_states[1]["state"]["entity_id"] == "input_boolean.b3" assert written_states[1]["state"]["entity_id"] == "input_boolean.b3"
assert written_states[1]["state"]["state"] == "off" assert written_states[1]["state"]["state"] == "off"
assert written_states[2]["state"]["entity_id"] == "input_boolean.b5"
assert written_states[2]["state"]["state"] == "off"
# Test that removed entities are not persisted # Test that removed entities are not persisted
await entity.async_remove() await entity.async_remove()
@ -159,9 +164,11 @@ async def test_dump_data(hass):
assert mock_write_data.called assert mock_write_data.called
args = mock_write_data.mock_calls[0][1] args = mock_write_data.mock_calls[0][1]
written_states = args[0] written_states = args[0]
assert len(written_states) == 1 assert len(written_states) == 2
assert written_states[0]["state"]["entity_id"] == "input_boolean.b3" assert written_states[0]["state"]["entity_id"] == "input_boolean.b3"
assert written_states[0]["state"]["state"] == "off" assert written_states[0]["state"]["state"] == "off"
assert written_states[1]["state"]["entity_id"] == "input_boolean.b5"
assert written_states[1]["state"]["state"] == "off"
async def test_dump_error(hass): async def test_dump_error(hass):