mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 18:27:51 +00:00
Avoid creating inner tasks to load storage (#117099)
This commit is contained in:
parent
ead69af27c
commit
03dcede211
@ -254,7 +254,7 @@ class Store(Generic[_T]):
|
|||||||
self._delay_handle: asyncio.TimerHandle | None = None
|
self._delay_handle: asyncio.TimerHandle | None = None
|
||||||
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
|
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
|
||||||
self._write_lock = asyncio.Lock()
|
self._write_lock = asyncio.Lock()
|
||||||
self._load_task: asyncio.Future[_T | None] | None = None
|
self._load_future: asyncio.Future[_T | None] | None = None
|
||||||
self._encoder = encoder
|
self._encoder = encoder
|
||||||
self._atomic_writes = atomic_writes
|
self._atomic_writes = atomic_writes
|
||||||
self._read_only = read_only
|
self._read_only = read_only
|
||||||
@ -276,27 +276,32 @@ class Store(Generic[_T]):
|
|||||||
Will ensure that when a call comes in while another one is in progress,
|
Will ensure that when a call comes in while another one is in progress,
|
||||||
the second call will wait and return the result of the first call.
|
the second call will wait and return the result of the first call.
|
||||||
"""
|
"""
|
||||||
if self._load_task:
|
if self._load_future:
|
||||||
return await self._load_task
|
return await self._load_future
|
||||||
|
|
||||||
load_task = self.hass.async_create_background_task(
|
self._load_future = self.hass.loop.create_future()
|
||||||
self._async_load(), f"Storage load {self.key}", eager_start=True
|
try:
|
||||||
)
|
result = await self._async_load()
|
||||||
if not load_task.done():
|
except BaseException as ex:
|
||||||
# Only set the load task if it didn't complete immediately
|
self._load_future.set_exception(ex)
|
||||||
self._load_task = load_task
|
# Ensure the future is marked as retrieved
|
||||||
return await load_task
|
# since if there is no concurrent call it
|
||||||
|
# will otherwise never be retrieved.
|
||||||
|
self._load_future.exception()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
self._load_future.set_result(result)
|
||||||
|
finally:
|
||||||
|
self._load_future = None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def _async_load(self) -> _T | None:
|
async def _async_load(self) -> _T | None:
|
||||||
"""Load the data and ensure the task is removed."""
|
"""Load the data and ensure the task is removed."""
|
||||||
if STORAGE_SEMAPHORE not in self.hass.data:
|
if STORAGE_SEMAPHORE not in self.hass.data:
|
||||||
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
|
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.hass.data[STORAGE_SEMAPHORE]:
|
async with self.hass.data[STORAGE_SEMAPHORE]:
|
||||||
return await self._async_load_data()
|
return await self._async_load_data()
|
||||||
finally:
|
|
||||||
self._load_task = None
|
|
||||||
|
|
||||||
async def _async_load_data(self):
|
async def _async_load_data(self):
|
||||||
"""Load the data."""
|
"""Load the data."""
|
||||||
|
@ -1159,3 +1159,21 @@ async def test_store_manager_cleanup_after_stop(
|
|||||||
assert store_manager.async_fetch("integration1") is None
|
assert store_manager.async_fetch("integration1") is None
|
||||||
assert store_manager.async_fetch("integration2") is None
|
assert store_manager.async_fetch("integration2") is None
|
||||||
await hass.async_stop(force=True)
|
await hass.async_stop(force=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_storage_concurrent_load(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that we can load the store concurrently."""
|
||||||
|
|
||||||
|
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
|
||||||
|
|
||||||
|
async def _load_store():
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
return "data"
|
||||||
|
|
||||||
|
with patch.object(store, "_async_load", side_effect=_load_store):
|
||||||
|
# Test that we can load the store concurrently
|
||||||
|
loads = await asyncio.gather(
|
||||||
|
store.async_load(), store.async_load(), store.async_load()
|
||||||
|
)
|
||||||
|
for load in loads:
|
||||||
|
assert load == "data"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user