Avoid creating inner tasks to load storage (#117099)

This commit is contained in:
J. Nick Koston 2024-05-08 16:41:20 -05:00 committed by GitHub
parent ead69af27c
commit 03dcede211
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 16 deletions

View File

@ -254,7 +254,7 @@ class Store(Generic[_T]):
self._delay_handle: asyncio.TimerHandle | None = None self._delay_handle: asyncio.TimerHandle | None = None
self._unsub_final_write_listener: CALLBACK_TYPE | None = None self._unsub_final_write_listener: CALLBACK_TYPE | None = None
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock()
self._load_task: asyncio.Future[_T | None] | None = None self._load_future: asyncio.Future[_T | None] | None = None
self._encoder = encoder self._encoder = encoder
self._atomic_writes = atomic_writes self._atomic_writes = atomic_writes
self._read_only = read_only self._read_only = read_only
@ -276,27 +276,32 @@ class Store(Generic[_T]):
Will ensure that when a call comes in while another one is in progress, Will ensure that when a call comes in while another one is in progress,
the second call will wait and return the result of the first call. the second call will wait and return the result of the first call.
""" """
if self._load_task: if self._load_future:
return await self._load_task return await self._load_future
load_task = self.hass.async_create_background_task( self._load_future = self.hass.loop.create_future()
self._async_load(), f"Storage load {self.key}", eager_start=True try:
) result = await self._async_load()
if not load_task.done(): except BaseException as ex:
# Only set the load task if it didn't complete immediately self._load_future.set_exception(ex)
self._load_task = load_task # Ensure the future is marked as retrieved
return await load_task # since if there is no concurrent call it
# will otherwise never be retrieved.
self._load_future.exception()
raise
else:
self._load_future.set_result(result)
finally:
self._load_future = None
return result
async def _async_load(self) -> _T | None: async def _async_load(self) -> _T | None:
"""Load the data and ensure the task is removed.""" """Load the data and ensure the task is removed."""
if STORAGE_SEMAPHORE not in self.hass.data: if STORAGE_SEMAPHORE not in self.hass.data:
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY) self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
try:
async with self.hass.data[STORAGE_SEMAPHORE]: async with self.hass.data[STORAGE_SEMAPHORE]:
return await self._async_load_data() return await self._async_load_data()
finally:
self._load_task = None
async def _async_load_data(self): async def _async_load_data(self):
"""Load the data.""" """Load the data."""

View File

@ -1159,3 +1159,21 @@ async def test_store_manager_cleanup_after_stop(
assert store_manager.async_fetch("integration1") is None assert store_manager.async_fetch("integration1") is None
assert store_manager.async_fetch("integration2") is None assert store_manager.async_fetch("integration2") is None
await hass.async_stop(force=True) await hass.async_stop(force=True)
async def test_storage_concurrent_load(hass: HomeAssistant) -> None:
"""Test that we can load the store concurrently."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
async def _load_store():
await asyncio.sleep(0)
return "data"
with patch.object(store, "_async_load", side_effect=_load_store):
# Test that we can load the store concurrently
loads = await asyncio.gather(
store.async_load(), store.async_load(), store.async_load()
)
for load in loads:
assert load == "data"