Avoid creating multiple tasks for config entry init (#110899)

This commit is contained in:
J. Nick Koston 2024-02-20 20:57:36 -06:00 committed by GitHub
parent 9ce1ec414e
commit 2f2cdedddd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 29 additions and 16 deletions

View File

@ -22,6 +22,8 @@ from random import randint
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Self, TypeVar, cast
from async_interrupt import interrupt
from . import data_entry_flow, loader
from .components import persistent_notification
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
@ -948,6 +950,10 @@ current_entry: ContextVar[ConfigEntry | None] = ContextVar(
)
class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
"""Manage all the config entry flows that are in progress."""
@ -962,7 +968,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self.config_entries = config_entries
self._hass_config = hass_config
self._pending_import_flows: dict[str, dict[str, asyncio.Future[None]]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
self._initialize_futures: dict[str, list[asyncio.Future[None]]] = {}
self._discovery_debouncer = Debouncer(
hass,
_LOGGER,
@ -994,20 +1000,26 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
raise KeyError("Context not set or doesn't have a source set")
flow_id = uuid_util.random_uuid_hex()
loop = self.hass.loop
if context["source"] == SOURCE_IMPORT:
init_done: asyncio.Future[None] = self.hass.loop.create_future()
self._pending_import_flows.setdefault(handler, {})[flow_id] = init_done
task = asyncio.create_task(
self._async_init(flow_id, handler, context, data),
name=f"config entry flow {handler} {flow_id}",
)
self._initialize_tasks.setdefault(handler, []).append(task)
self._pending_import_flows.setdefault(handler, {})[
flow_id
] = loop.create_future()
cancel_init_future = loop.create_future()
self._initialize_futures.setdefault(handler, []).append(cancel_init_future)
try:
flow, result = await task
async with interrupt(
cancel_init_future,
FlowCancelledError,
"Config entry initialize canceled: Home Assistant is shutting down",
):
flow, result = await self._async_init(flow_id, handler, context, data)
except FlowCancelledError as ex:
raise asyncio.CancelledError from ex
finally:
self._initialize_tasks[handler].remove(task)
self._initialize_futures[handler].remove(cancel_init_future)
self._pending_import_flows.get(handler, {}).pop(flow_id, None)
if result["type"] != data_entry_flow.FlowResultType.ABORT:
@ -1042,11 +1054,9 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
async def async_shutdown(self) -> None:
"""Cancel any initializing flows."""
for task_list in self._initialize_tasks.values():
for task in task_list:
task.cancel(
"Config entry initialize canceled: Home Assistant is shutting down"
)
for future_list in self._initialize_futures.values():
for future in future_list:
future.set_result(None)
await self._discovery_debouncer.async_shutdown()
async def async_finish_flow(

View File

@ -7,6 +7,7 @@ aiohttp-zlib-ng==0.3.1
aiohttp==3.9.3
aiohttp_cors==0.7.0
astral==2.2
async-interrupt==1.1.1
async-upnp-client==0.38.2
atomicwrites-homeassistant==1.4.1
attrs==23.2.0

View File

@ -28,6 +28,7 @@ dependencies = [
"aiohttp-fast-url-dispatcher==0.3.0",
"aiohttp-zlib-ng==0.3.1",
"astral==2.2",
"async-interrupt==1.1.1",
"attrs==23.2.0",
"atomicwrites-homeassistant==1.4.1",
"awesomeversion==24.2.0",

View File

@ -8,6 +8,7 @@ aiohttp_cors==0.7.0
aiohttp-fast-url-dispatcher==0.3.0
aiohttp-zlib-ng==0.3.1
astral==2.2
async-interrupt==1.1.1
attrs==23.2.0
atomicwrites-homeassistant==1.4.1
awesomeversion==24.2.0