async_get_instance was not reentrant during await (#38263)

This commit is contained in:
Joakim Plate 2020-08-12 22:35:24 +02:00 committed by GitHub
parent 8cf0a01149
commit 15db2225da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 39 deletions

View File

@ -2,7 +2,7 @@
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from typing import Any, Awaitable, Dict, List, Optional, Set, cast from typing import Any, Dict, List, Optional, Set, cast
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import ( from homeassistant.core import (
@ -17,6 +17,7 @@ from homeassistant.helpers import entity_registry
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -63,45 +64,39 @@ class RestoreStateData:
@classmethod @classmethod
async def async_get_instance(cls, hass: HomeAssistant) -> "RestoreStateData": async def async_get_instance(cls, hass: HomeAssistant) -> "RestoreStateData":
"""Get the singleton instance of this data helper.""" """Get the singleton instance of this data helper."""
task = hass.data.get(DATA_RESTORE_STATE_TASK)
if task is None: @singleton(DATA_RESTORE_STATE_TASK)
async def load_instance(hass: HomeAssistant) -> "RestoreStateData":
"""Get the singleton instance of this data helper."""
data = cls(hass)
async def load_instance(hass: HomeAssistant) -> "RestoreStateData": try:
"""Set up the restore state helper.""" stored_states = await data.store.async_load()
data = cls(hass) except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
stored_states = None
try: if stored_states is None:
stored_states = await data.store.async_load() _LOGGER.debug("Not creating cache - no saved states found")
except HomeAssistantError as exc: data.last_states = {}
_LOGGER.error("Error loading last states", exc_info=exc) else:
stored_states = None data.last_states = {
item["state"]["entity_id"]: StoredState.from_dict(item)
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(data.last_states))
if stored_states is None: if hass.state == CoreState.running:
_LOGGER.debug("Not creating cache - no saved states found") data.async_setup_dump()
data.last_states = {} else:
else: hass.bus.async_listen_once(
data.last_states = { EVENT_HOMEASSISTANT_START, data.async_setup_dump
item["state"]["entity_id"]: StoredState.from_dict(item) )
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(data.last_states))
if hass.state == CoreState.running: return data
data.async_setup_dump()
else:
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, data.async_setup_dump
)
return data return cast(RestoreStateData, await load_instance(hass))
task = hass.data[DATA_RESTORE_STATE_TASK] = hass.async_create_task(
load_instance(hass)
)
return await cast(Awaitable["RestoreStateData"], task)
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the restore state data class.""" """Initialize the restore state data class."""

View File

@ -818,11 +818,7 @@ def mock_restore_cache(hass, states):
_LOGGER.debug("Restore cache: %s", data.last_states) _LOGGER.debug("Restore cache: %s", data.last_states)
assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
async def get_restore_state_data() -> restore_state.RestoreStateData: hass.data[key] = data
return data
# Patch the singleton task in hass.data to return our new RestoreStateData
hass.data[key] = hass.async_create_task(get_restore_state_data())
class MockEntity(entity.Entity): class MockEntity(entity.Entity):

View File

@ -27,6 +27,7 @@ async def test_caching_data(hass):
] ]
data = await RestoreStateData.async_get_instance(hass) data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
await data.store.async_save([state.as_dict() for state in stored_states]) await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load # Emulate a fresh load
@ -41,6 +42,7 @@ async def test_caching_data(hass):
"homeassistant.helpers.restore_state.Store.async_save" "homeassistant.helpers.restore_state.Store.async_save"
) as mock_write_data: ) as mock_write_data:
state = await entity.async_get_last_state() state = await entity.async_get_last_state()
await hass.async_block_till_done()
assert state is not None assert state is not None
assert state.entity_id == "input_boolean.b1" assert state.entity_id == "input_boolean.b1"
@ -61,6 +63,7 @@ async def test_hass_starting(hass):
] ]
data = await RestoreStateData.async_get_instance(hass) data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
await data.store.async_save([state.as_dict() for state in stored_states]) await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load # Emulate a fresh load
@ -76,6 +79,7 @@ async def test_hass_starting(hass):
"homeassistant.helpers.restore_state.Store.async_save" "homeassistant.helpers.restore_state.Store.async_save"
) as mock_write_data, patch.object(hass.states, "async_all", return_value=states): ) as mock_write_data, patch.object(hass.states, "async_all", return_value=states):
state = await entity.async_get_last_state() state = await entity.async_get_last_state()
await hass.async_block_till_done()
assert state is not None assert state is not None
assert state.entity_id == "input_boolean.b1" assert state.entity_id == "input_boolean.b1"